diff --git a/.devops/full-cuda.Dockerfile b/.devops/full-cuda.Dockerfile index 61f671465..b8a354246 100644 --- a/.devops/full-cuda.Dockerfile +++ b/.devops/full-cuda.Dockerfile @@ -1,18 +1,16 @@ ARG UBUNTU_VERSION=22.04 - # This needs to generally match the container host's environment. -ARG CUDA_VERSION=11.7.1 - +ARG CUDA_VERSION=12.6.0 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} FROM ${BASE_CUDA_DEV_CONTAINER} AS build -# Unless otherwise specified, we make a fat build. -ARG CUDA_DOCKER_ARCH=all +# CUDA architecture to build for (defaults to all supported archs) +ARG CUDA_DOCKER_ARCH=default RUN apt-get update && \ - apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1 + apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1 COPY requirements.txt requirements.txt COPY requirements requirements @@ -24,13 +22,12 @@ WORKDIR /app COPY . . -# Set nvcc architecture -ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} -# Enable CUDA -ENV GGML_CUDA=1 -# Enable cURL -ENV LLAMA_CURL=1 - -RUN make -j$(nproc) +# Use the default CUDA archs if not specified +RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release --target llama-cli -j$(nproc) && \ + cp build/bin/* . ENTRYPOINT ["/app/.devops/tools.sh"] diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile new file mode 100644 index 000000000..db5ba2f25 --- /dev/null +++ b/.devops/llama-cli-cann.Dockerfile @@ -0,0 +1,44 @@ +ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8 + +FROM cosdt/cann:$ASCEND_VERSION AS build + +WORKDIR /app + +COPY . . + +RUN yum install -y gcc g++ cmake make +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +# find libascend_hal.so, because the drive hasn`t been mounted. +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH + +RUN echo "Building with static libs" && \ + source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ + cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \ + cmake --build build --config Release --target llama-cli + +# TODO: use image with NNRT +FROM cosdt/cann:$ASCEND_VERSION AS runtime +COPY --from=build /app/build/bin/llama-cli /llama-cli + +ENV LC_ALL=C.utf8 + +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +ENTRYPOINT ["/llama-cli" ] diff --git a/.devops/llama-cli-cuda.Dockerfile b/.devops/llama-cli-cuda.Dockerfile index 8eda63a89..b75163b94 100644 --- a/.devops/llama-cli-cuda.Dockerfile +++ b/.devops/llama-cli-cuda.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG CUDA_VERSION=11.7.1 +ARG CUDA_VERSION=12.6.0 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} # Target the CUDA runtime image @@ -8,28 +8,30 @@ ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_V FROM ${BASE_CUDA_DEV_CONTAINER} AS build -# Unless otherwise specified, we make a fat build. -ARG CUDA_DOCKER_ARCH=all +# CUDA architecture to build for (defaults to all supported archs) +ARG CUDA_DOCKER_ARCH=default RUN apt-get update && \ - apt-get install -y build-essential git + apt-get install -y build-essential git cmake WORKDIR /app COPY . . -# Set nvcc architecture -ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} -# Enable CUDA -ENV GGML_CUDA=1 - -RUN make -j$(nproc) llama-cli +# Use the default CUDA archs if not specified +RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_CUDA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release --target llama-cli -j$(nproc) FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime RUN apt-get update && \ apt-get install -y libgomp1 -COPY --from=build /app/llama-cli /llama-cli +COPY --from=build /app/build/ggml/src/libggml.so /libggml.so +COPY --from=build /app/build/src/libllama.so /libllama.so +COPY --from=build /app/build/bin/llama-cli /llama-cli ENTRYPOINT [ "/llama-cli" ] diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile index 67328cf1c..a40e24205 100644 --- a/.devops/llama-server-cuda.Dockerfile +++ b/.devops/llama-server-cuda.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG CUDA_VERSION=11.7.1 +ARG CUDA_VERSION=12.6.0 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} # Target the CUDA runtime image @@ -8,31 +8,34 @@ ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_V FROM ${BASE_CUDA_DEV_CONTAINER} AS build -# Unless otherwise specified, we make a fat build. -ARG CUDA_DOCKER_ARCH=all +# CUDA architecture to build for (defaults to all supported archs) +ARG CUDA_DOCKER_ARCH=default RUN apt-get update && \ - apt-get install -y build-essential git libcurl4-openssl-dev + apt-get install -y build-essential git cmake libcurl4-openssl-dev WORKDIR /app COPY . . -# Set nvcc architecture -ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} -# Enable CUDA -ENV GGML_CUDA=1 -# Enable cURL -ENV LLAMA_CURL=1 - -RUN make -j$(nproc) llama-server +# Use the default CUDA archs if not specified +RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release --target llama-server -j$(nproc) FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime RUN apt-get update && \ apt-get install -y libcurl4-openssl-dev libgomp1 curl -COPY --from=build /app/llama-server /llama-server +COPY --from=build /app/build/ggml/src/libggml.so /libggml.so +COPY --from=build /app/build/src/libllama.so /libllama.so +COPY --from=build /app/build/bin/llama-server /llama-server + +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.devops/llama-server-intel.Dockerfile b/.devops/llama-server-intel.Dockerfile index f525658dd..9c355b664 100644 --- a/.devops/llama-server-intel.Dockerfile +++ b/.devops/llama-server-intel.Dockerfile @@ -26,6 +26,8 @@ RUN apt-get update && \ COPY --from=build /app/build/bin/llama-server /llama-server ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile index 763b4cd3f..fd0e19ad6 100644 --- a/.devops/llama-server-rocm.Dockerfile +++ b/.devops/llama-server-rocm.Dockerfile @@ -39,6 +39,8 @@ ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang ENV CXX=/opt/rocm/llvm/bin/clang++ +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 # Enable cURL ENV LLAMA_CURL=1 diff --git a/.devops/llama-server-vulkan.Dockerfile b/.devops/llama-server-vulkan.Dockerfile index 13a61ffd8..93c5e0c26 100644 --- a/.devops/llama-server-vulkan.Dockerfile +++ b/.devops/llama-server-vulkan.Dockerfile @@ -23,6 +23,8 @@ RUN cp /app/build/bin/llama-server /llama-server && \ rm -rf /app ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.devops/llama-server.Dockerfile b/.devops/llama-server.Dockerfile index ff558604e..02accc85e 100644 --- a/.devops/llama-server.Dockerfile +++ b/.devops/llama-server.Dockerfile @@ -21,6 +21,8 @@ RUN apt-get update && \ COPY --from=build /app/llama-server /llama-server ENV LC_ALL=C.utf8 +# Must be set to 0.0.0.0 so it can listen to requests from host machine +ENV LLAMA_ARG_HOST=0.0.0.0 HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] diff --git a/.ecrc b/.ecrc index a3351f4e6..c68877ec2 100644 --- a/.ecrc +++ b/.ecrc @@ -1,5 +1,5 @@ { - "Exclude": ["^\\.gitmodules$"], + "Exclude": ["^\\.gitmodules$", "stb_image\\.h"], "Disable": { "IndentSize": true } diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml.disabled similarity index 98% rename from .github/workflows/bench.yml rename to .github/workflows/bench.yml.disabled index 56d22bc0c..bfdbb4ef5 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml.disabled @@ -1,3 +1,6 @@ +# TODO: there have been some issues with the workflow, so disabling for now +# https://github.com/ggerganov/llama.cpp/issues/7893 +# # Benchmark name: Benchmark diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index bf94b2024..56fefd93d 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -96,21 +96,12 @@ jobs: env: GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' - - name: Build and push Docker image (versioned) + - name: Build and push Docker image (tagged + versioned) if: github.event_name == 'push' - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v6 with: context: . push: true platforms: ${{ matrix.config.platforms }} - tags: "ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" - file: ${{ matrix.config.dockerfile }} - - - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v4 - with: - context: . - push: ${{ github.event_name == 'push' }} - platforms: ${{ matrix.config.platforms }} - tags: "ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ steps.tag.outputs.name }}" + tags: "ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ steps.tag.outputs.name }}" file: ${{ matrix.config.dockerfile }} diff --git a/.gitignore b/.gitignore index 5ae030200..9986ac6b1 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,6 @@ poetry.toml # Scripts !/scripts/install-oneapi.bat + +# Test models for lora adapters +/lora-tests diff --git a/CMakePresets.json b/CMakePresets.json index bdad38952..ce627b4d3 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -28,6 +28,7 @@ { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, { "name": "arm64-windows-msvc", "hidden": true, @@ -60,6 +61,8 @@ { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, - { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] } + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] } ] } diff --git a/README.md b/README.md index 7f48fde6e..bb2b93a35 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,8 @@ Typically finetunes of the base models below are supported as well. - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) +- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) +- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md)) @@ -424,6 +426,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md) | [CUDA](./docs/build.md#cuda) | Nvidia GPU | | [hipBLAS](./docs/build.md#hipblas) | AMD GPU | | [Vulkan](./docs/build.md#vulkan) | GPU | +| [CANN](./docs/build.md#cann) | Ascend NPU | ## Tools diff --git a/ci/run.sh b/ci/run.sh index 58022c7dc..751bb0a02 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -13,6 +13,9 @@ # # with SYCL support # GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with VULKAN support +# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -40,7 +43,7 @@ if [ ! -z ${GG_BUILD_METAL} ]; then fi if [ ! -z ${GG_BUILD_CUDA} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=1" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native" fi if [ ! -z ${GG_BUILD_SYCL} ]; then @@ -52,6 +55,10 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" fi + +if [ ! -z ${GG_BUILD_VULKAN} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1" +fi ## helpers # download a file if it does not exist or if it is outdated @@ -107,7 +114,7 @@ function gg_run_ctest_debug { gg_check_build_requirements (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log (time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log @@ -138,7 +145,7 @@ function gg_run_ctest_release { gg_check_build_requirements (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log if [ -z ${GG_BUILD_LOW_PERF} ]; then (time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log @@ -266,7 +273,6 @@ function gg_sum_ctest_with_model_release { } # open_llama_7b_v2 -# requires: GG_BUILD_CUDA function gg_run_open_llama_7b_v2 { cd ${SRC} @@ -290,8 +296,8 @@ function gg_run_open_llama_7b_v2 { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DGGML_CUDA=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../examples/convert_legacy_llama.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -425,7 +431,7 @@ function gg_run_pythia_1_4b { set -e (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -535,7 +541,6 @@ function gg_sum_pythia_1_4b { } # pythia_2_8b -# requires: GG_BUILD_CUDA function gg_run_pythia_2_8b { cd ${SRC} @@ -556,8 +561,8 @@ function gg_run_pythia_2_8b { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DGGML_CUDA=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -692,7 +697,7 @@ function gg_run_embd_bge_small { set -e (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf @@ -761,7 +766,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then fi if [ -z ${GG_BUILD_VRAM_GB} ] || [ ${GG_BUILD_VRAM_GB} -ge 8 ]; then - if [ -z ${GG_BUILD_CUDA} ]; then + if [ -z ${GG_BUILD_CUDA} ] && [ -z ${GG_BUILD_VULKAN} ]; then test $ret -eq 0 && gg_run pythia_1_4b else test $ret -eq 0 && gg_run pythia_2_8b diff --git a/common/common.cpp b/common/common.cpp index 874f1f4f8..39db42608 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -77,6 +77,41 @@ using json = nlohmann::ordered_json; +// +// Environment variable utils +// + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::string(value) : target; +} + +template +static typename std::enable_if::value && std::is_integral::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::stoi(value) : target; +} + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + target = value ? std::stof(value) : target; +} + +template +static typename std::enable_if::value, void>::type +get_env(std::string name, T & target) { + char * value = std::getenv(name.c_str()); + if (value) { + std::string val(value); + target = val == "1" || val == "true"; + } +} + // // CPU utils // @@ -110,8 +145,34 @@ int32_t cpu_get_num_physical_cores() { if (result == 0) { return num_physical_cores; } -#elif defined(_WIN32) - //TODO: Implement +#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + unsigned int n_threads_win = std::thread::hardware_concurrency(); + unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4; + + DWORD buffer_size = 0; + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) { + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) { + return default_threads; + } + } + + std::vector buffer(buffer_size); + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast(buffer.data()), &buffer_size)) { + return default_threads; + } + + int32_t num_physical_cores = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast(buffer.data()); + while (buffer_size > 0) { + if (info->Relationship == RelationProcessorCore) { + num_physical_cores += info->Processor.GroupCount; + } + buffer_size -= info->Size; + info = reinterpret_cast(reinterpret_cast(info) + info->Size); + } + + return num_physical_cores > 0 ? num_physical_cores : default_threads; #endif unsigned int n_threads = std::thread::hardware_concurrency(); return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; @@ -190,16 +251,61 @@ int32_t cpu_get_num_math() { return cpu_get_num_physical_cores(); } +// Helper for setting process priority + +#if defined(_WIN32) + +bool set_process_priority(enum ggml_sched_priority prio) { + if (prio == GGML_SCHED_PRIO_NORMAL) { + return true; + } + + DWORD p = NORMAL_PRIORITY_CLASS; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; + case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break; + } + + if (!SetPriorityClass(GetCurrentProcess(), p)) { + fprintf(stderr, "warn: failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#else // MacOS and POSIX +#include +#include + +bool set_process_priority(enum ggml_sched_priority prio) { + if (prio == GGML_SCHED_PRIO_NORMAL) { + return true; + } + + int p = 0; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: p = 0; break; + case GGML_SCHED_PRIO_MEDIUM: p = -5; break; + case GGML_SCHED_PRIO_HIGH: p = -10; break; + case GGML_SCHED_PRIO_REALTIME: p = -20; break; + } + + if (!setpriority(PRIO_PROCESS, 0, p)) { + fprintf(stderr, "warn: failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); + return false; + } + return true; +} + +#endif + // // CLI argument parsing // -void gpt_params_handle_hf_token(gpt_params & params) { - if (params.hf_token.empty() && std::getenv("HF_TOKEN")) { - params.hf_token = std::getenv("HF_TOKEN"); - } -} - void gpt_params_handle_model_default(gpt_params & params) { if (!params.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model @@ -222,6 +328,30 @@ void gpt_params_handle_model_default(gpt_params & params) { } } +void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) { + int32_t n_set = 0; + + if (cpuparams.n_threads < 0) { + // Assuming everything about cpuparams is invalid + if (role_model != nullptr) { + cpuparams = *role_model; + } else { + cpuparams.n_threads = cpu_get_num_math(); + } + } + + for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + if (cpuparams.cpumask[i]) { + n_set++; + } + } + + if (n_set && n_set < cpuparams.n_threads) { + // Not enough set bits, may experience performance issues. + fprintf(stderr, "warn: Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); + } +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -241,13 +371,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } + postprocess_cpu_params(params.cpuparams, nullptr); + postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams); + postprocess_cpu_params(params.draft_cpuparams, ¶ms.cpuparams); + postprocess_cpu_params(params.draft_cpuparams_batch, ¶ms.cpuparams_batch); + if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } gpt_params_handle_model_default(params); - gpt_params_handle_hf_token(params); + if (params.hf_token.empty()) { + get_env("HF_TOKEN", params.hf_token); + } if (params.escape) { string_process_escapes(params.prompt); @@ -267,6 +404,32 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { return true; } +void gpt_params_parse_from_env(gpt_params & params) { + // we only care about server-related params for now + get_env("LLAMA_ARG_MODEL", params.model); + get_env("LLAMA_ARG_MODEL_URL", params.model_url); + get_env("LLAMA_ARG_MODEL_ALIAS", params.model_alias); + get_env("LLAMA_ARG_HF_REPO", params.hf_repo); + get_env("LLAMA_ARG_HF_FILE", params.hf_file); + get_env("LLAMA_ARG_THREADS", params.cpuparams.n_threads); + get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx); + get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel); + get_env("LLAMA_ARG_BATCH", params.n_batch); + get_env("LLAMA_ARG_UBATCH", params.n_ubatch); + get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers); + get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http); + get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template); + get_env("LLAMA_ARG_N_PREDICT", params.n_predict); + get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics); + get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots); + get_env("LLAMA_ARG_EMBEDDINGS", params.embedding); + get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn); + get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold); + get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching); + get_env("LLAMA_ARG_HOST", params.hostname); + get_env("LLAMA_ARG_PORT", params.port); +} + bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { const auto params_org = params; // the example can modify the default params @@ -285,6 +448,79 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } +bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) { + size_t dash_loc = range.find('-'); + if (dash_loc == std::string::npos) { + fprintf(stderr, "Format of CPU range is invalid! Expected []-[].\n"); + return false; + } + + size_t start_i; + size_t end_i; + + if (dash_loc == 0) { + start_i = 0; + } else { + start_i = std::stoull(range.substr(0, dash_loc)); + if (start_i >= GGML_MAX_N_THREADS) { + fprintf(stderr, "Start index out of bounds!\n"); + return false; + } + } + + if (dash_loc == range.length() - 1) { + end_i = GGML_MAX_N_THREADS - 1; + } else { + end_i = std::stoull(range.substr(dash_loc + 1)); + if (end_i >= GGML_MAX_N_THREADS) { + fprintf(stderr, "End index out of bounds!\n"); + return false; + } + } + + for (size_t i = start_i; i <= end_i; i++) { + boolmask[i] = true; + } + + return true; +} + +bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) { + // Discard potential 0x prefix + size_t start_i = 0; + if (mask.length() >= 2 && mask.substr(0, 2) == "0x") { + start_i = 2; + } + + size_t num_digits = mask.length() - start_i; + if (num_digits > 128) num_digits = 128; + + size_t end_i = num_digits + start_i; + + for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) { + char c = mask.at(i); + int8_t id = c; + + if ((c >= '0' && c <= '9')) { + id -= '0'; + } else if (c >= 'a' && c <= 'f') { + id -= 'a' - 10; + } else if (c >= 'A' && c <= 'F') { + id -= 'A' - 10; + } else { + fprintf(stderr, "Invalid hex character '%c' at position %d\n", c, int32_t(i)); + return false; + } + + boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0); + boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0); + boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0); + boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0); + } + + return true; +} + #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { @@ -301,36 +537,142 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "-t" || arg == "--threads") { CHECK_ARG - params.n_threads = std::stoi(argv[i]); - if (params.n_threads <= 0) { - params.n_threads = std::thread::hardware_concurrency(); + params.cpuparams.n_threads = std::stoi(argv[i]); + if (params.cpuparams.n_threads <= 0) { + params.cpuparams.n_threads = std::thread::hardware_concurrency(); } return true; } + if (arg == "-C" || arg == "--cpu-mask") { + CHECK_ARG + std::string mask = argv[i]; + params.cpuparams.mask_valid = true; + invalid_param = !parse_cpu_mask(mask, params.cpuparams.cpumask); + return true; + } + if (arg == "-Cr" || arg == "--cpu-range") { + CHECK_ARG + std::string range = argv[i]; + params.cpuparams.mask_valid = true; + invalid_param = !parse_cpu_range(range, params.cpuparams.cpumask); + return true; + } + if (arg == "--prio") { + CHECK_ARG + params.cpuparams.priority = (enum ggml_sched_priority) std::stoul(argv[i]); + return true; + } + if (arg == "--cpu-strict") { + CHECK_ARG + params.cpuparams.strict_cpu = std::stoul(argv[i]); + return true; + } + if (arg == "--poll") { + CHECK_ARG + params.cpuparams.poll = std::stoul(argv[i]); + return true; + } if (arg == "-tb" || arg == "--threads-batch") { CHECK_ARG - params.n_threads_batch = std::stoi(argv[i]); - if (params.n_threads_batch <= 0) { - params.n_threads_batch = std::thread::hardware_concurrency(); + params.cpuparams_batch.n_threads = std::stoi(argv[i]); + if (params.cpuparams_batch.n_threads <= 0) { + params.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); } return true; } + if (arg == "-Cb" || arg == "--cpu-mask-batch") { + CHECK_ARG + std::string mask = argv[i]; + params.cpuparams_batch.mask_valid = true; + invalid_param = !parse_cpu_mask(mask, params.cpuparams_batch.cpumask); + return true; + } + if (arg == "-Crb" || arg == "--cpu-range_batch") { + CHECK_ARG + std::string range = argv[i]; + params.cpuparams_batch.mask_valid = true; + invalid_param = !parse_cpu_range(range, params.cpuparams_batch.cpumask); + return true; + } + if (arg == "--prio-batch") { + CHECK_ARG + params.cpuparams_batch.priority = (enum ggml_sched_priority) std::stoul(argv[i]); + return true; + } + if (arg == "--cpu-strict-batch") { + params.cpuparams_batch.strict_cpu = true; + return true; + } + if (arg == "--poll-batch") { + CHECK_ARG + params.cpuparams_batch.poll = std::stoul(argv[i]); + return true; + } if (arg == "-td" || arg == "--threads-draft") { CHECK_ARG - params.n_threads_draft = std::stoi(argv[i]); - if (params.n_threads_draft <= 0) { - params.n_threads_draft = std::thread::hardware_concurrency(); + params.draft_cpuparams.n_threads = std::stoi(argv[i]); + if (params.draft_cpuparams.n_threads <= 0) { + params.draft_cpuparams.n_threads = std::thread::hardware_concurrency(); } return true; + } + if (arg == "-Cd" || arg == "--cpu-mask-draft") { + CHECK_ARG + std::string mask = argv[i]; + params.draft_cpuparams.mask_valid = true; + invalid_param = !parse_cpu_mask(mask, params.draft_cpuparams.cpumask); + return true; + } + if (arg == "-Crd" || arg == "--cpu-range-draft") { + CHECK_ARG + std::string range = argv[i]; + params.draft_cpuparams.mask_valid = true; + invalid_param = !parse_cpu_range(range, params.draft_cpuparams.cpumask); + return true; + } + if (arg == "--prio-draft") { + CHECK_ARG + params.draft_cpuparams.priority = (enum ggml_sched_priority) std::stoul(argv[i]); + return true; + } + if (arg == "--cpu-strict-draft") { + params.draft_cpuparams.strict_cpu = true; + return true; + } + if (arg == "--poll-draft") { + CHECK_ARG + params.draft_cpuparams.poll = std::stoul(argv[i]); + return true; } if (arg == "-tbd" || arg == "--threads-batch-draft") { CHECK_ARG - params.n_threads_batch_draft = std::stoi(argv[i]); - if (params.n_threads_batch_draft <= 0) { - params.n_threads_batch_draft = std::thread::hardware_concurrency(); + params.draft_cpuparams_batch.n_threads = std::stoi(argv[i]); + if (params.draft_cpuparams_batch.n_threads <= 0) { + params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency(); } return true; } + if (arg == "-Crbd" || arg == "--cpu-range-batch-draft") { + CHECK_ARG + std::string range = argv[i]; + params.draft_cpuparams_batch.mask_valid = true; + invalid_param = !parse_cpu_range(range, params.draft_cpuparams_batch.cpumask); + return true; + } + if (arg == "--prio-batch-draft") { + CHECK_ARG + params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) std::stoul(argv[i]); + return true; + } + if (arg == "--cpu-strict-batch-draft") { + params.draft_cpuparams_batch.strict_cpu = true; + return true; + } + if (arg == "--poll-batch-draft") { + CHECK_ARG + params.draft_cpuparams_batch.poll = std::stoul(argv[i]); + return true; + } if (arg == "-p" || arg == "--prompt") { CHECK_ARG params.prompt = argv[i]; @@ -830,7 +1172,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } - if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { + if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") { CHECK_ARG params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { @@ -1420,11 +1762,40 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); - options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads }); + options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads }); options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); - options.push_back({ "speculative", "-tbd, --threads-batch-draft N", - "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); + options.push_back({ "speculative", "-tbd, --threads-batch-draft N","number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); + +#ifndef GGML_USE_OPENMP + // these options are available only with the internal threadpool + options.push_back({ "*", "-C, --cpu-mask M", "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")"}); + options.push_back({ "*", "-Cr, --cpu-range lo-hi", "range of CPUs for affinity. Complements --cpu-mask"}); + options.push_back({ "*", " --cpu-strict <0|1>", "use strict CPU placement (default: %u)\n", (unsigned) params.cpuparams.strict_cpu}); + options.push_back({ "*", " --priority N", "set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority}); + options.push_back({ "*", " --poll <0...100>", "use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll}); + + options.push_back({ "*", "-Cb, --cpu-mask-batch M", "CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)"}); + options.push_back({ "*", "-Crb, --cpu-range-batch lo-hi", "ranges of CPUs for affinity. Complements --cpu-mask-batch"}); + options.push_back({ "*", " --cpu-strict-batch <0|1>","use strict CPU placement (default: same as --cpu-strict)"}); + options.push_back({ "*", " --priority-batch N", "set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: --priority)"}); + options.push_back({ "*", " --poll-batch <0|1>", "use polling to wait for work (default: same as --poll"}); + + options.push_back({ "speculative", "-Cd, --cpu-mask-draft M", "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)"}); + options.push_back({ "speculative", "-Crd, --cpu-range-draft lo-hi", "Ranges of CPUs for affinity. Complements --cpu-mask-draft"}); + options.push_back({ "speculative", " --cpu-strict-draft <0|1>","Use strict CPU placement for draft model (default: same as --cpu-strict)"}); + options.push_back({ "speculative", " --priority-draft N", "Set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: same as --priority)"}); + options.push_back({ "speculative", " --poll-draft <0|1>", "Use polling to wait for draft model work (default: same as --poll])"}); + + options.push_back({ "speculative", "-Cbd, --cpu-mask-batch-draft M","Draft model CPU affinity mask. Complements cpu-range-draft-batch (default: same as --cpu-mask-draft)"}); + options.push_back({ "speculative", "-Crbd, --cpu-range-batch-draft lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)"}); + options.push_back({ "speculative", " --cpu-strict-batch-draft <0|1>", + "Use strict CPU placement for draft model (default: --cpu-strict-draft)"}); + options.push_back({ "speculative", " --priority-batch-draft N","Set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: --priority-draft)"}); + options.push_back({ "speculative", " --poll-batch-draft <0|1>","Use polling to wait for draft model work (default: --poll-draft)"}); +#endif // GGML_USE_OPENMP + options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", @@ -1698,7 +2069,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "export-lora", "-m, --model", "model path from which to load base model (default '%s')", params.model.c_str() }); options.push_back({ "export-lora", " --lora FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)" }); options.push_back({ "export-lora", " --lora-scaled FNAME S", "path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" }); - options.push_back({ "*", "-t, --threads N", "number of threads to use during computation (default: %d)", params.n_threads }); options.push_back({ "export-lora", "-o, --output FNAME", "output file (default: '%s')", params.lora_outfile.c_str() }); printf("usage: %s [options]\n", argv[0]); @@ -1730,11 +2100,17 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param std::string gpt_params_get_system_info(const gpt_params & params) { std::ostringstream os; - os << "system_info: n_threads = " << params.n_threads; - if (params.n_threads_batch != -1) { - os << " (n_threads_batch = " << params.n_threads_batch << ")"; + os << "system_info: n_threads = " << params.cpuparams.n_threads; + if (params.cpuparams_batch.n_threads != -1) { + os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")"; } +#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later + // TODO: windows + arm64 + mingw64 + DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS); + os << " / " << logicalProcessorCount << " | " << llama_print_system_info(); +#else os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); +#endif return os.str(); } @@ -1786,13 +2162,19 @@ std::string string_get_sortable_timestamp() { void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } void string_process_escapes(std::string & input) { @@ -2244,8 +2626,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_seq_max = params.n_parallel; cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + cparams.n_threads = params.cpuparams.n_threads; + cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? + params.cpuparams.n_threads : params.cpuparams_batch.n_threads; cparams.seed = params.seed; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; @@ -2271,6 +2654,22 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } +struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) { + struct ggml_threadpool_params tpp; + + ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults + + if (params.mask_valid) { + std::memcpy(&tpp.cpumask, ¶ms.cpumask, GGML_MAX_N_THREADS); + } + + tpp.prio = params.priority; + tpp.poll = params.poll; + tpp.strict_cpu = params.strict_cpu; + + return tpp; +} + #ifdef LLAMA_USE_CURL static bool starts_with(const std::string & str, const std::string & prefix) { @@ -2709,12 +3108,6 @@ std::string llama_detokenize(llama_context * ctx, const std::vector return text; } -bool llama_should_add_bos_token(const llama_model * model) { - const int add_bos = llama_add_bos_token(model); - - return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); -} - // // Chat template utils // @@ -3266,7 +3659,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); diff --git a/common/common.h b/common/common.h index acd6e4d92..a14c0f448 100644 --- a/common/common.h +++ b/common/common.h @@ -67,13 +67,18 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; +struct cpu_params { + int n_threads = -1; + bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask. + bool mask_valid = false; // Default: any CPU + enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime) + bool strict_cpu = false; // Use strict CPU placement + uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling) +}; + struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = cpu_get_num_math(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) @@ -100,6 +105,11 @@ struct gpt_params { int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + struct cpu_params draft_cpuparams; + struct cpu_params draft_cpuparams_batch; + ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; @@ -205,7 +215,7 @@ struct gpt_params { int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests + int n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) std::string hostname = "127.0.0.1"; std::string public_path = ""; @@ -268,7 +278,7 @@ struct gpt_params { std::string lora_outfile = "ggml-lora-merged-f16.gguf"; }; -void gpt_params_handle_hf_token(gpt_params & params); +void gpt_params_parse_from_env(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); @@ -278,6 +288,11 @@ void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params); std::string gpt_params_get_system_info(const gpt_params & params); +bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]); +bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]); +void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr); +bool set_process_priority(enum ggml_sched_priority prio); + // // String utils // @@ -328,8 +343,9 @@ struct llama_init_result { struct llama_init_result llama_init_from_gpt_params(gpt_params & params); -struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); +struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); +struct llama_context_params llama_context_params_from_gpt_params (const gpt_params & params); +struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); @@ -381,10 +397,6 @@ std::string llama_detokenize( const std::vector & tokens, bool special = true); -// Uses the value from the model metadata if possible, otherwise -// defaults to true when model type is SPM, otherwise false. -bool llama_should_add_bos_token(const llama_model * model); - // // Chat template utils // diff --git a/common/stb_image.h b/common/stb_image.h index 4766d7e67..9eedabedc 100644 --- a/common/stb_image.h +++ b/common/stb_image.h @@ -1,4 +1,4 @@ -/* stb_image - v2.28 - public domain image loader - http://nothings.org/stb +/* stb_image - v2.30 - public domain image loader - http://nothings.org/stb no warranty implied; use at your own risk Do this: @@ -48,6 +48,8 @@ LICENSE RECENT REVISION HISTORY: + 2.30 (2024-05-31) avoid erroneous gcc warning + 2.29 (2023-05-xx) optimizations 2.28 (2023-01-29) many error fixes, security errors, just tons of stuff 2.27 (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes 2.26 (2020-07-13) many minor fixes @@ -371,13 +373,14 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum { - STBI_default = 0, // only used for desired_channels +enum +{ + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -405,11 +408,11 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct { - int (*read)(void * user, char * data, - int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip)(void * user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof)(void * user); // returns nonzero if we are at end of file/data +typedef struct +{ + int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read + void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof) (void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -417,24 +420,21 @@ typedef struct { // 8-bits-per-channel interface // -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); // for stbi_load_from_file, file pointer is left pointing immediately after image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); #endif //////////////////////////////////// @@ -442,14 +442,12 @@ STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wc // 16-bits-per-channel interface // -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); #endif //////////////////////////////////// @@ -457,55 +455,56 @@ STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * chann // float-per-channel interface // #ifndef STBI_NO_LINEAR -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels); -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * channels_in_file, - int desired_channels); + STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); -#ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); -#endif + #ifndef STBI_NO_STDIO + STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); + STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); + #endif #endif #ifndef STBI_NO_HDR -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); -STBIDEF void stbi_hdr_to_ldr_scale(float scale); + STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); + STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); -STBIDEF void stbi_ldr_to_hdr_scale(float scale); + STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); + STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename); -STBIDEF int stbi_is_hdr_from_file(FILE * f); +STBIDEF int stbi_is_hdr (char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO + // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char * stbi_failure_reason(void); +STBIDEF const char *stbi_failure_reason (void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free(void * retval_from_stbi_load); +STBIDEF void stbi_image_free (void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * clbk, void * user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp); -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp); -STBIDEF int stbi_is_16_bit(char const * filename); -STBIDEF int stbi_is_16_bit_from_file(FILE * f); +STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit (char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif + + // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. @@ -527,14 +526,14 @@ STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_fli // ZLIB client - used by PNG, available for other purposes -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen); -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header); -STBIDEF char * stbi_zlib_decode_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); -STBIDEF char * stbi_zlib_decode_noheader_malloc(const char * buffer, int len, int * outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); #ifdef __cplusplus } @@ -547,50 +546,52 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ - defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ - defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) -#ifndef STBI_ONLY_JPEG -#define STBI_NO_JPEG -#endif -#ifndef STBI_ONLY_PNG -#define STBI_NO_PNG -#endif -#ifndef STBI_ONLY_BMP -#define STBI_NO_BMP -#endif -#ifndef STBI_ONLY_PSD -#define STBI_NO_PSD -#endif -#ifndef STBI_ONLY_TGA -#define STBI_NO_TGA -#endif -#ifndef STBI_ONLY_GIF -#define STBI_NO_GIF -#endif -#ifndef STBI_ONLY_HDR -#define STBI_NO_HDR -#endif -#ifndef STBI_ONLY_PIC -#define STBI_NO_PIC -#endif -#ifndef STBI_ONLY_PNM -#define STBI_NO_PNM -#endif +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ + || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ + || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ + || defined(STBI_ONLY_ZLIB) + #ifndef STBI_ONLY_JPEG + #define STBI_NO_JPEG + #endif + #ifndef STBI_ONLY_PNG + #define STBI_NO_PNG + #endif + #ifndef STBI_ONLY_BMP + #define STBI_NO_BMP + #endif + #ifndef STBI_ONLY_PSD + #define STBI_NO_PSD + #endif + #ifndef STBI_ONLY_TGA + #define STBI_NO_TGA + #endif + #ifndef STBI_ONLY_GIF + #define STBI_NO_GIF + #endif + #ifndef STBI_ONLY_HDR + #define STBI_NO_HDR + #endif + #ifndef STBI_ONLY_PIC + #define STBI_NO_PIC + #endif + #ifndef STBI_ONLY_PNM + #define STBI_NO_PNM + #endif #endif #if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) #define STBI_NO_ZLIB #endif -#include + #include #include // ptrdiff_t on osx #include #include +#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -608,54 +609,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha #define STBI_EXTERN extern #endif + #ifndef _MSC_VER -#ifdef __cplusplus -#define stbi_inline inline + #ifdef __cplusplus + #define stbi_inline inline + #else + #define stbi_inline + #endif #else -#define stbi_inline -#endif -#else -#define stbi_inline __forceinline + #define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS -#if defined(__cplusplus) && __cplusplus >= 201103L -#define STBI_THREAD_LOCAL thread_local -#elif defined(__GNUC__) && __GNUC__ < 5 -#define STBI_THREAD_LOCAL __thread -#elif defined(_MSC_VER) -#define STBI_THREAD_LOCAL __declspec(thread) -#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) -#define STBI_THREAD_LOCAL _Thread_local -#endif + #if defined(__cplusplus) && __cplusplus >= 201103L + #define STBI_THREAD_LOCAL thread_local + #elif defined(__GNUC__) && __GNUC__ < 5 + #define STBI_THREAD_LOCAL __thread + #elif defined(_MSC_VER) + #define STBI_THREAD_LOCAL __declspec(thread) + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) + #define STBI_THREAD_LOCAL _Thread_local + #endif -#ifndef STBI_THREAD_LOCAL -#if defined(__GNUC__) -#define STBI_THREAD_LOCAL __thread -#endif -#endif + #ifndef STBI_THREAD_LOCAL + #if defined(__GNUC__) + #define STBI_THREAD_LOCAL __thread + #endif + #endif #endif #if defined(_MSC_VER) || defined(__SYMBIAN32__) typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -663,9 +665,9 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL -#define stbi_lrot(x, y) _lrotl(x, y) + #define stbi_lrot(x,y) _lrotl(x,y) #else -#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (-(y)&31))) + #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (-(y) & 31))) #endif #if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) @@ -677,13 +679,13 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p, newsz) realloc(p, newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p,newsz) realloc(p,newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) +#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) #endif // x86/x64 detection @@ -725,31 +727,34 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) { - int info[4]; - __cpuid(info, 1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) +{ + int info[4]; + __cpuid(info,1); + return info[3]; } #else -static int stbi__cpuid3(void) { - int res; - __asm { +static int stbi__cpuid3(void) +{ + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) { - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) +{ + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -757,11 +762,12 @@ static int stbi__sse2_available(void) { #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) { - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) +{ + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -796,162 +802,190 @@ static int stbi__sse2_available(void) { // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct { - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct +{ + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void * io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; -static void stbi__refill_buffer(stbi__context * s); + +static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context * s, stbi_uc const * buffer, int len) { - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) +{ + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context * s, stbi_io_callbacks * c, void * user) { - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) +{ + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void * user, char * data, int size) { return (int)fread(data, 1, size, (FILE *)user); } - -static void stbi__stdio_skip(void * user, int n) { - int ch; - fseek((FILE *)user, n, SEEK_CUR); - ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ - } +static int stbi__stdio_read(void *user, char *data, int size) +{ + return (int) fread(data,1,size,(FILE*) user); } -static int stbi__stdio_eof(void * user) { return feof((FILE *)user) || ferror((FILE *)user); } +static void stbi__stdio_skip(void *user, int n) +{ + int ch; + fseek((FILE*) user, n, SEEK_CUR); + ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ + } +} -static stbi_io_callbacks stbi__stdio_callbacks = { - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static int stbi__stdio_eof(void *user) +{ + return feof((FILE*) user) || ferror((FILE *) user); +} + +static stbi_io_callbacks stbi__stdio_callbacks = +{ + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context * s, FILE * f) { stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } +static void stbi__start_file(stbi__context *s, FILE *f) +{ + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +} -// static void stop_file(stbi__context *s) { } +//static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context * s) { - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) +{ + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; +enum +{ + STBI_ORDER_RGB, + STBI_ORDER_BGR +}; -typedef struct { - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct +{ + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context * s); -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context * s); -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__png_is16(stbi__context * s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context * s); -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context * s); -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s); -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc); -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__psd_is16(stbi__context * s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context * s); -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context * s); -static void * stbi__pic_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context * s); -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp); -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s); -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp); -static int stbi__pnm_is16(stbi__context * s); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_is16(stbi__context *s); #endif static #ifdef STBI_THREAD_LOCAL - STBI_THREAD_LOCAL +STBI_THREAD_LOCAL #endif - const char * stbi__g_failure_reason; +const char *stbi__g_failure_reason; -STBIDEF const char * stbi_failure_reason(void) { return stbi__g_failure_reason; } +STBIDEF const char *stbi_failure_reason(void) +{ + return stbi__g_failure_reason; +} #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char * str) { - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) +{ + stbi__g_failure_reason = str; + return 0; } #endif -static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } +static void *stbi__malloc(size_t size) +{ + return STBI_MALLOC(size); +} // stb_image uses ints pervasively, including for offset calculations. // therefore the largest decoded image size we can support with the @@ -965,88 +999,88 @@ static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) { - if (b < 0) - return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) +{ + if (b < 0) return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) { - if (a < 0 || b < 0) - return 0; - if (b == 0) - return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX / b; +static int stbi__mul2sizes_valid(int a, int b) +{ + if (a < 0 || b < 0) return 0; + if (b == 0) return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX/b; } #if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__addsizes_valid(a * b * c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__addsizes_valid(a*b*c, add); } // returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__mul2sizes_valid(a * b * c, d) && - stbi__addsizes_valid(a * b * c * d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); } #endif #if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void * stbi__malloc_mad2(int a, int b, int add) { - if (!stbi__mad2sizes_valid(a, b, add)) - return NULL; - return stbi__malloc(a * b + add); +static void *stbi__malloc_mad2(int a, int b, int add) +{ + if (!stbi__mad2sizes_valid(a, b, add)) return NULL; + return stbi__malloc(a*b + add); } #endif -static void * stbi__malloc_mad3(int a, int b, int c, int add) { - if (!stbi__mad3sizes_valid(a, b, c, add)) - return NULL; - return stbi__malloc(a * b * c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) +{ + if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; + return stbi__malloc(a*b*c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) -static void * stbi__malloc_mad4(int a, int b, int c, int d, int add) { - if (!stbi__mad4sizes_valid(a, b, c, d, add)) - return NULL; - return stbi__malloc(a * b * c * d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) +{ + if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; + return stbi__malloc(a*b*c*d + add); } #endif // returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow. -static int stbi__addints_valid(int a, int b) { - if ((a >= 0) != (b >= 0)) - return 1; // a and b have different signs, so no overflow - if (a < 0 && b < 0) - return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. - return a <= INT_MAX - b; +static int stbi__addints_valid(int a, int b) +{ + if ((a >= 0) != (b >= 0)) return 1; // a and b have different signs, so no overflow + if (a < 0 && b < 0) return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a <= INT_MAX - b; } -// returns 1 if the product of two signed shorts is valid, 0 on overflow. -static int stbi__mul2shorts_valid(short a, short b) { - if (b == 0 || b == -1) - return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow - if ((a >= 0) == (b >= 0)) - return a <= SHRT_MAX / b; // product is positive, so similar to mul2sizes_valid - if (b < 0) - return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN - return a >= SHRT_MIN / b; +// returns 1 if the product of two ints fits in a signed short, 0 on overflow. +static int stbi__mul2shorts_valid(int a, int b) +{ + if (b == 0 || b == -1) return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + if ((a >= 0) == (b >= 0)) return a <= SHRT_MAX/b; // product is positive, so similar to mul2sizes_valid + if (b < 0) return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a >= SHRT_MIN / b; } // stbi__err - error @@ -1054,411 +1088,423 @@ static int stbi__mul2shorts_valid(short a, short b) { // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS -#define stbi__err(x, y) 0 + #define stbi__err(x,y) 0 #elif defined(STBI_FAILURE_USERMSG) -#define stbi__err(x, y) stbi__err(y) + #define stbi__err(x,y) stbi__err(y) #else -#define stbi__err(x, y) stbi__err(x) + #define stbi__err(x,y) stbi__err(x) #endif -#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -#define stbi__errpuc(x, y) ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) -STBIDEF void stbi_image_free(void * retval_from_stbi_load) { STBI_FREE(retval_from_stbi_load); } +STBIDEF void stbi_image_free(void *retval_from_stbi_load) +{ + STBI_FREE(retval_from_stbi_load); +} #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load \ - (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ + ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order + ri->num_channels = 0; -// test the formats with a very explicit header first (at least a FOURCC -// or distinctive magic number first) -#ifndef STBI_NO_PNG - if (stbi__png_test(s)) - return stbi__png_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) - return stbi__bmp_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_GIF - if (stbi__gif_test(s)) - return stbi__gif_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_PSD - if (stbi__psd_test(s)) - return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); -#else - STBI_NOTUSED(bpc); -#endif -#ifndef STBI_NO_PIC - if (stbi__pic_test(s)) - return stbi__pic_load(s, x, y, comp, req_comp, ri); -#endif + // test the formats with a very explicit header first (at least a FOURCC + // or distinctive magic number first) + #ifndef STBI_NO_PNG + if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_GIF + if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PSD + if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); + #else + STBI_NOTUSED(bpc); + #endif + #ifndef STBI_NO_PIC + if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); + #endif -// then the formats that can end up attempting to load with just 1 or 2 -// bytes matching expectations; these are prone to false positives, so -// try them later -#ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) - return stbi__jpeg_load(s, x, y, comp, req_comp, ri); -#endif -#ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) - return stbi__pnm_load(s, x, y, comp, req_comp, ri); -#endif + // then the formats that can end up attempting to load with just 1 or 2 + // bytes matching expectations; these are prone to false positives, so + // try them later + #ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); + #endif -#ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float * hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } -#endif + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } + #endif -#ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s, x, y, comp, req_comp, ri); -#endif + #ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s,x,y,comp,req_comp, ri); + #endif - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); + return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); } -static stbi_uc * stbi__convert_16_to_8(stbi__uint16 * orig, int w, int h, int channels) { - int i; - int img_len = w * h * channels; - stbi_uc * reduced; +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi_uc *reduced; - reduced = (stbi_uc *)stbi__malloc(img_len); - if (reduced == NULL) - return stbi__errpuc("outofmem", "Out of memory"); + reduced = (stbi_uc *) stbi__malloc(img_len); + if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + for (i = 0; i < img_len; ++i) + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - STBI_FREE(orig); - return reduced; + STBI_FREE(orig); + return reduced; } -static stbi__uint16 * stbi__convert_8_to_16(stbi_uc * orig, int w, int h, int channels) { - int i; - int img_len = w * h * channels; - stbi__uint16 * enlarged; +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; - enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); - if (enlarged == NULL) - return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); + if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff - STBI_FREE(orig); - return enlarged; + STBI_FREE(orig); + return enlarged; } -static void stbi__vertical_flip(void * image, int w, int h, int bytes_per_pixel) { - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc * bytes = (stbi_uc *)image; +static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) +{ + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; - for (row = 0; row < (h >> 1); row++) { - stbi_uc * row0 = bytes + row * bytes_per_row; - stbi_uc * row1 = bytes + (h - row - 1) * bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + for (row = 0; row < (h>>1); row++) { + stbi_uc *row0 = bytes + row*bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void * image, int w, int h, int z, int bytes_per_pixel) { - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) +{ + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc * bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char * stbi__load_and_postprocess_8bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *)result; + return (unsigned char *) result; } -static stbi__uint16 * stbi__load_and_postprocess_16bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - stbi__result_info ri; - void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *)result; + return (stbi__uint16 *) result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float * result, int * x, int * y, int * comp, int req_comp) { - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) +{ + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char * str, - int cbmb, wchar_t * widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, - const wchar_t * widestr, int cchwide, char * str, int cbmb, - const char * defchar, int * used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input) { - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int)bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); } #endif -static FILE * stbi__fopen(char const * filename, char const * mode) { - FILE * f; +static FILE *stbi__fopen(char const *filename, char const *mode) +{ + FILE *f; #if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename) / sizeof(*wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode) / sizeof(*wMode))) - return 0; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; #if defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f = 0; + if (0 != fopen_s(&f, filename, mode)) + f=0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; + return f; } -STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - unsigned char * result; - if (!f) - return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f, x, y, comp, req_comp); - fclose(f); - return result; + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; } -STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { - unsigned char * result; - stbi__context s; - stbi__start_file(&s, f); - result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; } -STBIDEF stbi__uint16 * stbi_load_from_file_16(FILE * f, int * x, int * y, int * comp, int req_comp) { - stbi__uint16 * result; - stbi__context s; - stbi__start_file(&s, f); - result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; } -STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * comp, int req_comp) { - FILE * f = stbi__fopen(filename, "rb"); - stbi__uint16 * result; - if (!f) - return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f, x, y, comp, req_comp); - fclose(f); - return result; +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f,x,y,comp,req_comp); + fclose(f); + return result; } -#endif //! STBI_NO_STDIO -STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, - int desired_channels) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +#endif //!STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); } -STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, - int * channels_in_file, int desired_channels) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); } -STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); } -STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, - int * comp, int req_comp) { - unsigned char * result; - stbi__context s; - stbi__start_mem(&s, buffer, len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_mem(&s,buffer,len); - result = (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices(result, *x, *y, *z, *comp); - } + result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float * stbi__loadf_main(stbi__context * s, int * x, int * y, int * comp, int req_comp) { - unsigned char * data; -#ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float * hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data, x, y, comp, req_comp); - return hdr_data; - } -#endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *data; + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data,x,y,comp,req_comp); + return hdr_data; + } + #endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); } -STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__loadf_main(&s, x, y, comp, req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__loadf_main(&s,x,y,comp,req_comp); } -STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, - int req_comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__loadf_main(&s, x, y, comp, req_comp); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__loadf_main(&s,x,y,comp,req_comp); } #ifndef STBI_NO_STDIO -STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * comp, int req_comp) { - float * result; - FILE * f = stbi__fopen(filename, "rb"); - if (!f) - return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f, x, y, comp, req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; } -STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { - stbi__context s; - stbi__start_file(&s, f); - return stbi__loadf_main(&s, x, y, comp, req_comp); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_file(&s,f); + return stbi__loadf_main(&s,x,y,comp,req_comp); } #endif // !STBI_NO_STDIO @@ -1468,208 +1514,222 @@ STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len) { -#ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__hdr_test(&s); -#else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; + #endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); - int result = 0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr (char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result=0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE * f) { -#ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s, f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; -#else - STBI_NOTUSED(f); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) +{ + #ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s,f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; + #else + STBI_NOTUSED(f); + return 0; + #endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user) { -#ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__hdr_test(&s); -#else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; -#endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; + #endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; +static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } #endif -static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; + +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1 / gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1 / scale; } ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; +enum +{ + STBI__SCAN_load=0, + STBI__SCAN_type, + STBI__SCAN_header +}; -static void stbi__refill_buffer(stbi__context * s) { - int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); - s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + 1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } +static void stbi__refill_buffer(stbi__context *s) +{ + int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); + s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start+1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } } -stbi_inline static stbi_uc stbi__get8(stbi__context * s) { - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; +stbi_inline static stbi_uc stbi__get8(stbi__context *s) +{ + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; } #if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context * s) { - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) - return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) - return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) +{ + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ - defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context * s, int n) { - if (n == 0) - return; // already there! - if (n < 0) { - s->img_buffer = s->img_buffer_end; - return; - } - if (s->io.read) { - int blen = (int)(s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; +static void stbi__skip(stbi__context *s, int n) +{ + if (n == 0) return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context * s, stbi_uc * buffer, int n) { - if (s->io.read) { - int blen = (int)(s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) +{ + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); - res = (count == (n - blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); + res = (count == (n-blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer + n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer+n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif #if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context * s) { - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) +{ + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context * s) { - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) +{ + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context * s) { - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) +{ + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context * s) { - stbi__uint32 z = stbi__get16le(s); - z += (stbi__uint32)stbi__get16le(s) << 16; - return z; +static stbi__uint32 stbi__get32le(stbi__context *s) +{ + stbi__uint32 z = stbi__get16le(s); + z += (stbi__uint32)stbi__get16le(s) << 16; + return z; } #endif -#define STBI__BYTECAST(x) ((stbi_uc)((x)&255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ - defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1683,264 +1743,169 @@ static stbi__uint32 stbi__get32le(stbi__context * s) { // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) { return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi_uc stbi__compute_y(int r, int g, int b) +{ + return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +} #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ - defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else -static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int req_comp, unsigned int x, unsigned int y) { - int i, j; - unsigned char * good; +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + unsigned char *good; - if (req_comp == img_n) - return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } + good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } - for (j = 0; j < (int)y; ++j) { - unsigned char * src = data + j * x * img_n; - unsigned char * dest = good + j * x * req_comp; + for (j=0; j < (int) y; ++j) { + unsigned char *src = data + j * x * img_n ; + unsigned char *dest = good + j * x * req_comp; -#define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ - for (i = x - 1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1, 2) { - dest[0] = src[0]; - dest[1] = 255; - } - break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(1, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = 255; - } - break; - STBI__CASE(2, 1) { dest[0] = src[0]; } - break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(2, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = src[1]; - } - break; - STBI__CASE(3, 4) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - dest[3] = 255; - } - break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - break; - STBI__CASE(3, 2) { - dest[0] = stbi__compute_y(src[0], src[1], src[2]); - dest[1] = 255; - } - break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - break; - STBI__CASE(4, 2) { - dest[0] = stbi__compute_y(src[0], src[1], src[2]); - dest[1] = src[3]; - } - break; - STBI__CASE(4, 3) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - } - break; - default: - STBI_ASSERT(0); - STBI_FREE(data); - STBI_FREE(good); - return stbi__errpuc("unsupported", "Unsupported format conversion"); - } -#undef STBI__CASE - } + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) +{ + return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +} #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int req_comp, unsigned int x, unsigned int y) { - int i, j; - stbi__uint16 * good; +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + stbi__uint16 *good; - if (req_comp == img_n) - return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); - } + good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + } - for (j = 0; j < (int)y; ++j) { - stbi__uint16 * src = data + j * x * img_n; - stbi__uint16 * dest = good + j * x * req_comp; + for (j=0; j < (int) y; ++j) { + stbi__uint16 *src = data + j * x * img_n ; + stbi__uint16 *dest = good + j * x * req_comp; -#define STBI__COMBO(a, b) ((a)*8 + (b)) -#define STBI__CASE(a, b) \ - case STBI__COMBO(a, b): \ - for (i = x - 1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1, 2) { - dest[0] = src[0]; - dest[1] = 0xffff; - } - break; - STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(1, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = 0xffff; - } - break; - STBI__CASE(2, 1) { dest[0] = src[0]; } - break; - STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } - break; - STBI__CASE(2, 4) { - dest[0] = dest[1] = dest[2] = src[0]; - dest[3] = src[1]; - } - break; - STBI__CASE(3, 4) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - dest[3] = 0xffff; - } - break; - STBI__CASE(3, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - break; - STBI__CASE(3, 2) { - dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); - dest[1] = 0xffff; - } - break; - STBI__CASE(4, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - break; - STBI__CASE(4, 2) { - dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); - dest[1] = src[3]; - } - break; - STBI__CASE(4, 3) { - dest[0] = src[0]; - dest[1] = src[1]; - dest[2] = src[2]; - } - break; - default: - STBI_ASSERT(0); - STBI_FREE(data); - STBI_FREE(good); - return (stbi__uint16 *)stbi__errpuc("unsupported", "Unsupported format conversion"); - } -#undef STBI__CASE - } + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp) { - int i, k, n; - float * output; - if (!data) - return NULL; - output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { - STBI_FREE(data); - return stbi__errpf("outofmem", "Out of memory"); - } - // compute number of non-alpha components - if (comp & 1) - n = comp; - else - n = comp - 1; - for (i = 0; i < x * y; ++i) { - for (k = 0; k < n; ++k) { - output[i * comp + k] = (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i = 0; i < x * y; ++i) { - output[i * comp + n] = data[i * comp + n] / 255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) +{ + int i,k,n; + float *output; + if (!data) return NULL; + output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); + } + } + if (n < comp) { + for (i=0; i < x*y; ++i) { + output[i*comp + n] = data[i*comp + n]/255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int)(x)) -static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { - int i, k, n; - stbi_uc * output; - if (!data) - return NULL; - output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - // compute number of non-alpha components - if (comp & 1) - n = comp; - else - n = comp - 1; - for (i = 0; i < x * y; ++i) { - for (k = 0; k < n; ++k) { - float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) - z = 0; - if (z > 255) - z = 255; - output[i * comp + k] = (stbi_uc)stbi__float2int(z); - } - if (k < comp) { - float z = data[i * comp + k] * 255 + 0.5f; - if (z < 0) - z = 0; - if (z > 255) - z = 255; - output[i * comp + k] = (stbi_uc)stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int) (x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) +{ + int i,k,n; + stbi_uc *output; + if (!data) return NULL; + output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + if (k < comp) { + float z = data[i*comp+k] * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1968,783 +1933,763 @@ static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache -typedef struct { - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +typedef struct +{ + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct { - stbi__context * s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; +typedef struct +{ + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; - // sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; +// sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; - // definition of jpeg image component - struct { - int id; - int h, v; - int tq; - int hd, ha; - int dc_pred; +// definition of jpeg image component + struct + { + int id; + int h,v; + int tq; + int hd,ha; + int dc_pred; - int x, y, w2, h2; - stbi_uc * data; - void *raw_data, *raw_coeff; - stbi_uc * linebuf; - short * coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; + int x,y,w2,h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; - int scan_n, order[4]; - int restart_interval, todo; + int scan_n, order[4]; + int restart_interval, todo; - // kernels - void (*idct_block_kernel)(stbi_uc * out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, - int step); - stbi_uc * (*resample_row_hv_2_kernel)(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs); +// kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman * h, int * count) { - int i, j, k = 0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i = 0; i < 16; ++i) { - for (j = 0; j < count[i]; ++j) { - h->size[k++] = (stbi_uc)(i + 1); - if (k >= 257) - return stbi__err("bad size list", "Corrupt JPEG"); - } - } - h->size[k] = 0; +static int stbi__build_huffman(stbi__huffman *h, int *count) +{ + int i,j,k=0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i=0; i < 16; ++i) { + for (j=0; j < count[i]; ++j) { + h->size[k++] = (stbi_uc) (i+1); + if(k >= 257) return stbi__err("bad size list","Corrupt JPEG"); + } + } + h->size[k] = 0; - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for (j = 1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16)(code++); - if (code - 1 >= (1u << j)) - return stbi__err("bad code lengths", "Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16 - j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for(j=1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16) (code++); + if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16-j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i = 0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS - s); - int m = 1 << (FAST_BITS - s); - for (j = 0; j < m; ++j) { - h->fast[c + j] = (stbi_uc)i; - } - } - } - return 1; + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i=0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS-s); + int m = 1 << (FAST_BITS-s); + for (j=0; j < m; ++j) { + h->fast[c+j] = (stbi_uc) i; + } + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 * fast_ac, stbi__huffman * h) { - int i; - for (i = 0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) +{ + int i; + for (i=0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) - k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); - } - } - } + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); + } + } + } } -static void stbi__grow_buffer_unsafe(stbi__jpeg * j) { - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) - c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char)c; - j->nomore = 1; - return; - } - } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) +{ + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char) c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17] = {0, 1, 3, 7, 15, 31, 63, 127, 255, - 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; +static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg * j, stbi__huffman * h) { - unsigned int temp; - int c, k; +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) +{ + unsigned int temp; + int c,k; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k = FAST_BITS + 1;; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k=FAST_BITS+1 ; ; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } - if (k > j->code_bits) - return -1; + if (k > j->code_bits) + return -1; - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - if (c < 0 || c >= 256) // symbol id out of bounds! - return -1; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + if(c < 0 || c >= 256) // symbol id out of bounds! + return -1; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) +{ + unsigned int k; + int sgn; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing - sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & (sgn - 1)); + sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & (sgn - 1)); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg * j, int n) { - unsigned int k; - if (j->code_bits < n) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < n) - return 0; // ran out of bits from stream, return 0s intead of continuing - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) +{ + unsigned int k; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s intead of continuing + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; } -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg * j) { - unsigned int k; - if (j->code_bits < 1) - stbi__grow_buffer_unsafe(j); - if (j->code_bits < 1) - return 0; // ran out of bits from stream, return 0s intead of continuing - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) +{ + unsigned int k; + if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); + if (j->code_bits < 1) return 0; // ran out of bits from stream, return 0s intead of continuing + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { - 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, - 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; +static const stbi_uc stbi__jpeg_dezigzag[64+15] = +{ + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63 +}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman * hdc, stbi__huffman * hac, stbi__int16 * fac, - int b, stbi__uint16 * dequant) { - int diff, dc, k; - int t; +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) +{ + int diff,dc,k; + int t; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0 || t > 15) - return stbi__err("bad huffman code", "Corrupt JPEG"); + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("bad huffman code","Corrupt JPEG"); - // 0 all the ac values now so we can do it 32-bits at a time - memset(data, 0, 64 * sizeof(data[0])); + // 0 all the ac values now so we can do it 32-bits at a time + memset(data,0,64*sizeof(data[0])); - diff = t ? stbi__extend_receive(j, t) : 0; - if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) - return stbi__err("bad delta", "Corrupt JPEG"); - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - if (!stbi__mul2shorts_valid(dc, dequant[0])) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - data[0] = (short)(dc * dequant[0]); + diff = t ? stbi__extend_receive(j, t) : 0; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta","Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, dequant[0])) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * dequant[0]); - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c, r, s; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - if (s > j->code_bits) - return stbi__err("bad huffman code", "Combined length longer than code bits available"); - j->code_buffer <<= s; - j->code_bits -= s; + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) break; // end block + k += 16; + } else { + k += r; // decode into unzigzag'd location zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)((r >> 8) * dequant[zig]); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) - break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); - } - } - } while (k < 64); - return 1; + data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); + } + } + } while (k < 64); + return 1; } -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi__huffman * hdc, int b) { - int diff, dc; - int t; - if (j->spec_end != 0) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) +{ + int diff,dc; + int t; + if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0 || t > 15) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data,0,64*sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; - if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) - return stbi__err("bad delta", "Corrupt JPEG"); - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - data[0] = (short)(dc * (1 << j->succ_low)); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short)(1 << j->succ_low); - } - return 1; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * (1 << j->succ_low)); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short) (1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi__huffman * hac, stbi__int16 * fac) { - int k; - if (j->spec_start == 0) - return stbi__err("can't merge dc and ac", "Corrupt JPEG"); +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) +{ + int k; + if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - if (j->succ_high == 0) { - int shift = j->succ_low; + if (j->succ_high == 0) { + int shift = j->succ_low; - if (j->eob_run) { - --j->eob_run; - return 1; - } + if (j->eob_run) { + --j->eob_run; + return 1; + } - k = j->spec_start; - do { - unsigned int zig; - int c, r, s; - if (j->code_bits < 16) - stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - if (s > j->code_bits) - return stbi__err("bad huffman code", "Combined length longer than code bits available"); - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)((r >> 8) * (1 << shift)); + k = j->spec_start; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * (1 << shift)); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short)(stbi__extend_receive(j, s) * (1 << shift)); - } + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * (1 << shift)); } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients + } + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients - short bit = (short)(1 << j->succ_low); + short bit = (short) (1 << j->succ_low); - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short * p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit) == 0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { + k = j->spec_start; + do { + int r,s; + int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; } - } else { - k = j->spec_start; - do { - int r, s; - int rs = stbi__jpeg_huff_decode( - j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) - return stbi__err("bad huffman code", "Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) - return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } - // advance by r - while (k <= j->spec_end) { - short * p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit) == 0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short)s; - break; - } - --r; - } - } - } while (k <= j->spec_end); - } - } - return 1; + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short) s; + break; + } + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) { - // trick to use a single test to catch both cases - if ((unsigned int)x > 255) { - if (x < 0) - return 0; - if (x > 255) - return 255; - } - return (stbi_uc)x; +stbi_inline static stbi_uc stbi__clamp(int x) +{ + // trick to use a single test to catch both cases + if ((unsigned int) x > 255) { + if (x < 0) return 0; + if (x > 255) return 255; + } + return (stbi_uc) x; } -#define stbi__f2f(x) ((int)(((x)*4096 + 0.5))) -#define stbi__fsh(x) ((x)*4096) +#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ - int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ - t3 = p1 + p2 * stbi__f2f(0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2 + p3); \ - t1 = stbi__fsh(p2 - p3); \ - x0 = t0 + t3; \ - x3 = t0 - t3; \ - x1 = t1 + t2; \ - x2 = t1 - t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0 + t2; \ - p4 = t1 + t3; \ - p1 = t0 + t3; \ - p2 = t1 + t2; \ - p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ - t0 = t0 * stbi__f2f(0.298631336f); \ - t1 = t1 * stbi__f2f(2.053119869f); \ - t2 = t2 * stbi__f2f(3.072711026f); \ - t3 = t3 * stbi__f2f(1.501321110f); \ - p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ - p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ - p3 = p3 * stbi__f2f(-1.961570560f); \ - p4 = p4 * stbi__f2f(-0.390180644f); \ - t3 += p1 + p4; \ - t2 += p2 + p3; \ - t1 += p2 + p4; \ - t0 += p1 + p3; +#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ + int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2+p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3*stbi__f2f(-1.847759065f); \ + t3 = p1 + p2*stbi__f2f( 0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2+p3); \ + t1 = stbi__fsh(p2-p3); \ + x0 = t0+t3; \ + x3 = t0-t3; \ + x1 = t1+t2; \ + x2 = t1-t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0+t2; \ + p4 = t1+t3; \ + p1 = t0+t3; \ + p2 = t1+t2; \ + p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ + t0 = t0*stbi__f2f( 0.298631336f); \ + t1 = t1*stbi__f2f( 2.053119869f); \ + t2 = t2*stbi__f2f( 3.072711026f); \ + t3 = t3*stbi__f2f( 1.501321110f); \ + p1 = p5 + p1*stbi__f2f(-0.899976223f); \ + p2 = p5 + p2*stbi__f2f(-2.562915447f); \ + p3 = p3*stbi__f2f(-1.961570560f); \ + p4 = p4*stbi__f2f(-0.390180644f); \ + t3 += p1+p4; \ + t2 += p2+p3; \ + t1 += p2+p4; \ + t0 += p1+p3; -static void stbi__idct_block(stbi_uc * out, int out_stride, short data[64]) { - int i, val[64], *v = val; - stbi_uc * o; - short * d = data; +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) +{ + int i,val[64],*v=val; + stbi_uc *o; + short *d = data; - // columns - for (i = 0; i < 8; ++i, ++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && d[48] == 0 && d[56] == 0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0] * 4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; - x1 += 512; - x2 += 512; - x3 += 512; - v[0] = (x0 + t3) >> 10; - v[56] = (x0 - t3) >> 10; - v[8] = (x1 + t2) >> 10; - v[48] = (x1 - t2) >> 10; - v[16] = (x2 + t1) >> 10; - v[40] = (x2 - t1) >> 10; - v[24] = (x3 + t0) >> 10; - v[32] = (x3 - t0) >> 10; - } - } + // columns + for (i=0; i < 8; ++i,++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 + && d[40]==0 && d[48]==0 && d[56]==0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0]*4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; x1 += 512; x2 += 512; x3 += 512; + v[ 0] = (x0+t3) >> 10; + v[56] = (x0-t3) >> 10; + v[ 8] = (x1+t2) >> 10; + v[48] = (x1-t2) >> 10; + v[16] = (x2+t1) >> 10; + v[40] = (x2-t1) >> 10; + v[24] = (x3+t0) >> 10; + v[32] = (x3-t0) >> 10; + } + } - for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128 << 17); - x1 += 65536 + (128 << 17); - x2 += 65536 + (128 << 17); - x3 += 65536 + (128 << 17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0 + t3) >> 17); - o[7] = stbi__clamp((x0 - t3) >> 17); - o[1] = stbi__clamp((x1 + t2) >> 17); - o[6] = stbi__clamp((x1 - t2) >> 17); - o[2] = stbi__clamp((x2 + t1) >> 17); - o[5] = stbi__clamp((x2 - t1) >> 17); - o[3] = stbi__clamp((x3 + t0) >> 17); - o[4] = stbi__clamp((x3 - t0) >> 17); - } + for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128<<17); + x1 += 65536 + (128<<17); + x2 += 65536 + (128<<17); + x3 += 65536 + (128<<17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0+t3) >> 17); + o[7] = stbi__clamp((x0-t3) >> 17); + o[1] = stbi__clamp((x1+t2) >> 17); + o[6] = stbi__clamp((x1-t2) >> 17); + o[2] = stbi__clamp((x2+t1) >> 17); + o[5] = stbi__clamp((x2-t1) >> 17); + o[3] = stbi__clamp((x3+t0) >> 17); + o[4] = stbi__clamp((x3-t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; -// dot product constant: even elems=x, odd elems=y -#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + // dot product constant: even elems=x, odd elems=y + #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) -// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) -// out(1) = c1[even]*x + c1[odd]*y -#define dct_rot(out0, out1, x, y, c0, c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) + // out(1) = c1[even]*x + c1[odd]*y + #define dct_rot(out0,out1, x,y,c0,c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) -// out = in << 12 (in 16-bit, out 32-bit) -#define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + // out = in << 12 (in 16-bit, out 32-bit) + #define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) -// wide add -#define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + // wide add + #define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) -// wide sub -#define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + // wide sub + #define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) -// butterfly a/b, add bias, then shift by "s" and pack -#define dct_bfly32o(out0, out1, a, b, bias, s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } + // butterfly a/b, add bias, then shift by "s" and pack + #define dct_bfly32o(out0, out1, a,b,bias,s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } -// 8-bit interleave step (for transposes) -#define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) + // 8-bit interleave step (for transposes) + #define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) -// 16-bit interleave step (for transposes) -#define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) + // 16-bit interleave step (for transposes) + #define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) -#define dct_pass(bias, shift) \ - { \ - /* even part */ \ - dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ - dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0, row7, x0, x7, bias, shift); \ - dct_bfly32o(row1, row6, x1, x6, bias, shift); \ - dct_bfly32o(row2, row5, x2, x5, bias, shift); \ - dct_bfly32o(row3, row4, x3, x4, bias, shift); \ - } + #define dct_pass(bias,shift) \ + { \ + /* even part */ \ + dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ + dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0,row7, x0,x7,bias,shift); \ + dct_bfly32o(row1,row6, x1,x6,bias,shift); \ + dct_bfly32o(row2,row5, x2,x5,bias,shift); \ + dct_bfly32o(row3,row4, x3,x4,bias,shift); \ + } - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); + __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - // load - row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); - row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); - row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); - row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); - row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); - row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); - row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); - row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + // load + row0 = _mm_load_si128((const __m128i *) (data + 0*8)); + row1 = _mm_load_si128((const __m128i *) (data + 1*8)); + row2 = _mm_load_si128((const __m128i *) (data + 2*8)); + row3 = _mm_load_si128((const __m128i *) (data + 3*8)); + row4 = _mm_load_si128((const __m128i *) (data + 4*8)); + row5 = _mm_load_si128((const __m128i *) (data + 5*8)); + row6 = _mm_load_si128((const __m128i *) (data + 6*8)); + row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - // column pass - dct_pass(bias_0, 10); + // column pass + dct_pass(bias_0, 10); - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } - // row pass - dct_pass(bias_1, 17); + // row pass + dct_pass(bias_1, 17); - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... - // store - _mm_storel_epi64((__m128i *)out, p0); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p2); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p1); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); - out += out_stride; - _mm_storel_epi64((__m128i *)out, p3); - out += out_stride; - _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); - } + // store + _mm_storel_epi64((__m128i *) out, p0); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p2); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p1); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p3); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2763,235 +2708,198 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } +#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ - dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ - dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ - dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ - } +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ + dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ + dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ + dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ + } - // load - row0 = vld1q_s16(data + 0 * 8); - row1 = vld1q_s16(data + 1 * 8); - row2 = vld1q_s16(data + 2 * 8); - row3 = vld1q_s16(data + 3 * 8); - row4 = vld1q_s16(data + 4 * 8); - row5 = vld1q_s16(data + 5 * 8); - row6 = vld1q_s16(data + 6 * 8); - row7 = vld1q_s16(data + 7 * 8); + // load + row0 = vld1q_s16(data + 0*8); + row1 = vld1q_s16(data + 1*8); + row2 = vld1q_s16(data + 2*8); + row3 = vld1q_s16(data + 3*8); + row4 = vld1q_s16(data + 4*8); + row5 = vld1q_s16(data + 5*8); + row6 = vld1q_s16(data + 6*8); + row7 = vld1q_s16(data + 7*8); - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - // column pass - dct_pass(vrshrn_n_s32, 10); + // column pass + dct_pass(vrshrn_n_s32, 10); - // 16bit 8x8 transpose - { + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) \ - { \ - int16x8x2_t t = vtrnq_s16(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ - } -#define dct_trn32(x, y) \ - { \ - int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ - x = vreinterpretq_s16_s32(t.val[0]); \ - y = vreinterpretq_s16_s32(t.val[1]); \ - } -#define dct_trn64(x, y) \ - { \ - int16x8_t x0 = x; \ - int16x8_t y0 = y; \ - x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ - y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ - } +#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } +#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } + } - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) \ - { \ - uint8x8x2_t t = vtrn_u8(x, y); \ - x = t.val[0]; \ - y = t.val[1]; \ - } -#define dct_trn8_16(x, y) \ - { \ - uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ - x = vreinterpret_u8_u16(t.val[0]); \ - y = vreinterpret_u8_u16(t.val[1]); \ - } -#define dct_trn8_32(x, y) \ - { \ - uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ - x = vreinterpret_u8_u32(t.val[0]); \ - y = vreinterpret_u8_u32(t.val[1]); \ - } + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } +#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); - // store - vst1_u8(out, p0); - out += out_stride; - vst1_u8(out, p1); - out += out_stride; - vst1_u8(out, p2); - out += out_stride; - vst1_u8(out, p3); - out += out_stride; - vst1_u8(out, p4); - out += out_stride; - vst1_u8(out, p5); - out += out_stride; - vst1_u8(out, p6); - out += out_stride; - vst1_u8(out, p7); + // store + vst1_u8(out, p0); out += out_stride; + vst1_u8(out, p1); out += out_stride; + vst1_u8(out, p2); out += out_stride; + vst1_u8(out, p3); out += out_stride; + vst1_u8(out, p4); out += out_stride; + vst1_u8(out, p5); out += out_stride; + vst1_u8(out, p6); out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -3004,1267 +2912,1169 @@ static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg * j) { - stbi_uc x; - if (j->marker != STBI__MARKER_none) { - x = j->marker; - j->marker = STBI__MARKER_none; - return x; - } - x = stbi__get8(j->s); - if (x != 0xff) - return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) +{ + stbi_uc x; + if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } + x = stbi__get8(j->s); + if (x != 0xff) return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg * j) { - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels +static void stbi__jpeg_reset(stbi__jpeg *j) +{ + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels } -static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i, j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, z->fast_ac[ha], n, - z->dequant[z->img_comp[n].tq])) - return 0; - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) +{ + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i,j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } else { // interleaved - int i, j, k, x, y; - STBI_SIMD_ALIGN(short, data[64]); - for (j = 0; j < z->img_mcu_y; ++j) { - for (i = 0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k = 0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y = 0; y < z->img_comp[n].v; ++y) { - for (x = 0; x < z->img_comp[n].h; ++x) { - int x2 = (i * z->img_comp[n].h + x) * 8; - int y2 = (j * z->img_comp[n].v + y) * 8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, - z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) - return 0; - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * y2 + x2, z->img_comp[n].w2, - data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } - } else { - if (z->scan_n == 1) { - int i, j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { + } + return 1; + } else { // interleaved + int i,j,k,x,y; + STBI_SIMD_ALIGN(short, data[64]); + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x)*8; + int y2 = (j*z->img_comp[n].v + y)*8; int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } else { // interleaved - int i, j, k, x, y; - for (j = 0; j < z->img_mcu_y; ++j) { - for (i = 0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k = 0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y = 0; y < z->img_comp[n].v; ++y) { - for (x = 0; x < z->img_comp[n].h; ++x) { - int x2 = (i * z->img_comp[n].h + x); - int y2 = (j * z->img_comp[n].v + y); - short * data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) - stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) - return 1; - stbi__jpeg_reset(z); - } - } + } + return 1; + } + } else { + if (z->scan_n == 1) { + int i,j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } } - return 1; - } - } + } + return 1; + } else { // interleaved + int i,j,k,x,y; + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x); + int y2 = (j*z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } } -static void stbi__jpeg_dequantize(short * data, stbi__uint16 * dequant) { - int i; - for (i = 0; i < 64; ++i) - data[i] *= dequant[i]; +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) +{ + int i; + for (i=0; i < 64; ++i) + data[i] *= dequant[i]; } -static void stbi__jpeg_finish(stbi__jpeg * z) { - if (z->progressive) { - // dequantize and idct the data - int i, j, n; - for (n = 0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x + 7) >> 3; - int h = (z->img_comp[n].y + 7) >> 3; - for (j = 0; j < h; ++j) { - for (i = 0; i < w; ++i) { - short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); - } +static void stbi__jpeg_finish(stbi__jpeg *z) +{ + if (z->progressive) { + // dequantize and idct the data + int i,j,n; + for (n=0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); } - } - } + } + } + } } -static int stbi__process_marker(stbi__jpeg * z, int m) { - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker", "Corrupt JPEG"); +static int stbi__process_marker(stbi__jpeg *z, int m) +{ + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker","Corrupt JPEG"); - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) - return stbi__err("bad DRI len", "Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s) - 2; - while (L > 0) { + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s)-2; + while (L > 0) { int q = stbi__get8(z->s); int p = q >> 4, sixteen = (p != 0); - int t = q & 15, i; - if (p != 0 && p != 1) - return stbi__err("bad DQT type", "Corrupt JPEG"); - if (t > 3) - return stbi__err("bad DQT table", "Corrupt JPEG"); + int t = q & 15,i; + if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); + if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - for (i = 0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + for (i=0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); L -= (sixteen ? 129 : 65); - } - return L == 0; + } + return L==0; - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s) - 2; - while (L > 0) { - stbi_uc * v; - int sizes[16], i, n = 0; + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s)-2; + while (L > 0) { + stbi_uc *v; + int sizes[16],i,n=0; int q = stbi__get8(z->s); int tc = q >> 4; int th = q & 15; - if (tc > 1 || th > 3) - return stbi__err("bad DHT header", "Corrupt JPEG"); - for (i = 0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; + if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); + for (i=0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; } - if (n > 256) - return stbi__err("bad DHT header", "Corrupt JPEG"); // Loop over i < n would write past end of values! + if(n > 256) return stbi__err("bad DHT header","Corrupt JPEG"); // Loop over i < n would write past end of values! L -= 17; if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc + th, sizes)) - return 0; - v = z->huff_dc[th].values; + if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; + v = z->huff_dc[th].values; } else { - if (!stbi__build_huffman(z->huff_ac + th, sizes)) - return 0; - v = z->huff_ac[th].values; + if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; + v = z->huff_ac[th].values; } - for (i = 0; i < n; ++i) - v[i] = stbi__get8(z->s); + for (i=0; i < n; ++i) + v[i] = stbi__get8(z->s); if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); L -= n; - } - return L == 0; - } + } + return L==0; + } - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len", "Corrupt JPEG"); - else - return stbi__err("bad APP len", "Corrupt JPEG"); - } - L -= 2; + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len","Corrupt JPEG"); + else + return stbi__err("bad APP len","Corrupt JPEG"); + } + L -= 2; - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; - int ok = 1; - int i; - for (i = 0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; - int ok = 1; - int i; - for (i = 0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J','F','I','F','\0'}; + int ok = 1; + int i; + for (i=0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; + int ok = 1; + int i; + for (i=0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } - } + } + } - stbi__skip(z->s, L); - return 1; - } + stbi__skip(z->s, L); + return 1; + } - return stbi__err("unknown marker", "Corrupt JPEG"); + return stbi__err("unknown marker","Corrupt JPEG"); } // after we see SOS -static int stbi__process_scan_header(stbi__jpeg * z) { - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) - return stbi__err("bad SOS component count", "Corrupt JPEG"); - if (Ls != 6 + 2 * z->scan_n) - return stbi__err("bad SOS len", "Corrupt JPEG"); - for (i = 0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; - int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) - return 0; // no match - z->img_comp[which].hd = q >> 4; - if (z->img_comp[which].hd > 3) - return stbi__err("bad DC huff", "Corrupt JPEG"); - z->img_comp[which].ha = q & 15; - if (z->img_comp[which].ha > 3) - return stbi__err("bad AC huff", "Corrupt JPEG"); - z->order[i] = which; - } +static int stbi__process_scan_header(stbi__jpeg *z) +{ + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); + if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); + for (i=0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) return 0; // no match + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); + z->order[i] = which; + } - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); - } else { - if (z->spec_start != 0) - return stbi__err("bad SOS", "Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) - return stbi__err("bad SOS", "Corrupt JPEG"); - z->spec_end = 63; - } - } + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); + z->spec_end = 63; + } + } - return 1; + return 1; } -static int stbi__free_jpeg_components(stbi__jpeg * z, int ncomp, int why) { - int i; - for (i = 0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) +{ + int i; + for (i=0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; } -static int stbi__process_frame_header(stbi__jpeg * z, int scan) { - stbi__context * s = z->s; - int Lf, p, i, q, h_max = 1, v_max = 1, c; - Lf = stbi__get16be(s); - if (Lf < 11) - return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG - p = stbi__get8(s); - if (p != 8) - return stbi__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); - if (s->img_y == 0) - return stbi__err("no header height", - "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); - if (s->img_x == 0) - return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) - return stbi__err("bad component count", "Corrupt JPEG"); - s->img_n = c; - for (i = 0; i < c; ++i) { - z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } +static int stbi__process_frame_header(stbi__jpeg *z, int scan) +{ + stbi__context *s = z->s; + int Lf,p,i,q, h_max=1,v_max=1,c; + Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG + p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); + s->img_n = c; + for (i=0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } - if (Lf != 8 + 3 * s->img_n) - return stbi__err("bad SOF len", "Corrupt JPEG"); + if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - z->rgb = 0; - for (i = 0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = {'R', 'G', 'B'}; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); - if (!z->img_comp[i].h || z->img_comp[i].h > 4) - return stbi__err("bad H", "Corrupt JPEG"); - z->img_comp[i].v = q & 15; - if (!z->img_comp[i].v || z->img_comp[i].v > 4) - return stbi__err("bad V", "Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); - if (z->img_comp[i].tq > 3) - return stbi__err("bad TQ", "Corrupt JPEG"); - } + z->rgb = 0; + for (i=0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = { 'R', 'G', 'B' }; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); + } - if (scan != STBI__SCAN_load) - return 1; + if (scan != STBI__SCAN_load) return 1; - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) - return stbi__err("too large", "Image too large to decode"); + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - for (i = 0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) - h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) - v_max = z->img_comp[i].v; - } + for (i=0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; + } - // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios - // and I've never seen a non-corrupted JPEG file actually use them - for (i = 0; i < s->img_n; ++i) { - if (h_max % z->img_comp[i].h != 0) - return stbi__err("bad H", "Corrupt JPEG"); - if (v_max % z->img_comp[i].v != 0) - return stbi__err("bad V", "Corrupt JPEG"); - } + // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios + // and I've never seen a non-corrupted JPEG file actually use them + for (i=0; i < s->img_n; ++i) { + if (h_max % z->img_comp[i].h != 0) return stbi__err("bad H","Corrupt JPEG"); + if (v_max % z->img_comp[i].v != 0) return stbi__err("bad V","Corrupt JPEG"); + } - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - for (i = 0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); - } - } + for (i=0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) -#define stbi__SOF_progressive(x) ((x) == 0xc2) +#define stbi__SOF_progressive(x) ((x) == 0xc2) -static int stbi__decode_jpeg_header(stbi__jpeg * z, int scan) { - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) - return stbi__err("no SOI", "Corrupt JPEG"); - if (scan == STBI__SCAN_type) - return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z, m)) - return 0; - m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) - return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) - return 0; - return 1; +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) +{ + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); + if (scan == STBI__SCAN_type) return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z,m)) return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); + m = stbi__get_marker(z); + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) return 0; + return 1; } -static int stbi__skip_jpeg_junk_at_end(stbi__jpeg * j) { - // some JPEGs have junk at end, skip over it but if we find what looks - // like a valid marker, resume there - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - while (x == 255) { // might be a marker - if (stbi__at_eof(j->s)) - return STBI__MARKER_none; - x = stbi__get8(j->s); - if (x != 0x00 && x != 0xff) { - // not a stuffed zero or lead-in to another marker, looks - // like an actual marker, return it - return x; - } - // stuffed zero has x=0 now which ends the loop, meaning we go - // back to regular scan loop. - // repeated 0xff keeps trying to read the next byte of the marker. - } - } - return STBI__MARKER_none; +static stbi_uc stbi__skip_jpeg_junk_at_end(stbi__jpeg *j) +{ + // some JPEGs have junk at end, skip over it but if we find what looks + // like a valid marker, resume there + while (!stbi__at_eof(j->s)) { + stbi_uc x = stbi__get8(j->s); + while (x == 0xff) { // might be a marker + if (stbi__at_eof(j->s)) return STBI__MARKER_none; + x = stbi__get8(j->s); + if (x != 0x00 && x != 0xff) { + // not a stuffed zero or lead-in to another marker, looks + // like an actual marker, return it + return x; + } + // stuffed zero has x=0 now which ends the loop, meaning we go + // back to regular scan loop. + // repeated 0xff keeps trying to read the next byte of the marker. + } + } + return STBI__MARKER_none; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg * j) { - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) - return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) - return 0; - if (!stbi__parse_entropy_coded_data(j)) - return 0; - if (j->marker == STBI__MARKER_none) { - j->marker = stbi__skip_jpeg_junk_at_end(j); - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } +static int stbi__decode_jpeg_image(stbi__jpeg *j) +{ + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) return 0; + if (!stbi__parse_entropy_coded_data(j)) return 0; + if (j->marker == STBI__MARKER_none ) { + j->marker = stbi__skip_jpeg_junk_at_end(j); + // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 + } + m = stbi__get_marker(j); + if (STBI__RESTART(m)) m = stbi__get_marker(j); - if (STBI__RESTART(m)) - m = stbi__get_marker(j); - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) - return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) - return stbi__err("bad DNL height", "Corrupt JPEG"); - m = stbi__get_marker(j); - } else { - if (!stbi__process_marker(j, m)) - return 1; - m = stbi__get_marker(j); - } - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); + m = stbi__get_marker(j); + } else { + if (!stbi__process_marker(j, m)) return 1; + m = stbi__get_marker(j); + } + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) -typedef stbi_uc * (*resample_row_func)(stbi_uc * out, stbi_uc * in0, stbi_uc * in1, int w, int hs); +typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, + int w, int hs); -#define stbi__div4(x) ((stbi_uc)((x) >> 2)) +#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) -static stbi_uc * resample_row_1(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; } -static stbi_uc * stbi__resample_row_v_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i = 0; i < w; ++i) - out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); - return out; +static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i=0; i < w; ++i) + out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); + return out; } -static stbi_uc * stbi__resample_row_h_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate two samples horizontally for every one in input - int i; - stbi_uc * input = in_near; +static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } - out[0] = input[0]; - out[1] = stbi__div4(input[0] * 3 + input[1] + 2); - for (i = 1; i < w - 1; ++i) { - int n = 3 * input[i] + 2; - out[i * 2 + 0] = stbi__div4(n + input[i - 1]); - out[i * 2 + 1] = stbi__div4(n + input[i + 1]); - } - out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); - out[i * 2 + 1] = input[w - 1]; + out[0] = input[0]; + out[1] = stbi__div4(input[0]*3 + input[1] + 2); + for (i=1; i < w-1; ++i) { + int n = 3*input[i]+2; + out[i*2+0] = stbi__div4(n+input[i-1]); + out[i*2+1] = stbi__div4(n+input[i+1]); + } + out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); + out[i*2+1] = input[w-1]; - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); - return out; + return out; } -#define stbi__div16(x) ((stbi_uc)((x) >> 4)) +#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) -static stbi_uc * stbi__resample_row_hv_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate 2x2 samples for every one in input - int i, t0, t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); - return out; - } +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i,t0,t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } - t1 = 3 * in_near[0] + in_far[0]; - out[0] = stbi__div4(t1 + 2); - for (i = 1; i < w; ++i) { - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - } - out[w * 2 - 1] = stbi__div4(t1 + 2); + t1 = 3*in_near[0] + in_far[0]; + out[0] = stbi__div4(t1+2); + for (i=1; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } #if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // need to generate 2x2 samples for every one in input - int i = 0, t0, t1; +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i=0,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); - return out; - } + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } - t1 = 3 * in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w - 1) & ~7); i += 8) { + t1 = 3*in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w-1) & ~7); i += 8) { #if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *)(out + i * 2), outv); + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *) (out + i*2), outv); #elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i * 2, o); + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i*2, o); #endif - // "previous" value for next iter - t1 = 3 * in_near[i + 7] + in_far[i + 7]; - } + // "previous" value for next iter + t1 = 3*in_near[i+7] + in_far[i+7]; + } - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + t0 = t1; + t1 = 3*in_near[i] + in_far[i]; + out[i*2] = stbi__div16(3*t1 + t0 + 8); - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3 * in_near[i] + in_far[i]; - out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); - out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - } - out[w * 2 - 1] = stbi__div4(t1 + 2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } #endif -static stbi_uc * stbi__resample_row_generic(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { - // resample with nearest-neighbor - int i, j; - STBI_NOTUSED(in_far); - for (i = 0; i < w; ++i) - for (j = 0; j < hs; ++j) - out[i * hs + j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // resample with nearest-neighbor + int i,j; + STBI_NOTUSED(in_far); + for (i=0; i < w; ++i) + for (j=0; j < hs; ++j) + out[i*hs+j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int)((x)*4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, - int step) { - int i; - for (i = 0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding - int r, g, b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr * stbi__float2fixed(1.40200f); - g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb * stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned)r > 255) { - if (r < 0) - r = 0; - else - r = 255; - } - if ((unsigned)g > 255) { - if (g < 0) - g = 0; - else - g = 255; - } - if ((unsigned)b > 255) { - if (b < 0) - b = 0; - else - b = 255; - } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) +{ + int i; + for (i=0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc const * pcb, stbi_uc const * pcr, int count, - int step) { - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) +{ + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); - __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); - __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); - __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); - __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); - __m128i xw = _mm_set1_epi16(255); // alpha channel + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); + __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); + __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); + __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); + __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); + __m128i xw = _mm_set1_epi16(255); // alpha channel - for (; i + 7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + for (; i+7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); - // store - _mm_storeu_si128((__m128i *)(out + 0), o0); - _mm_storeu_si128((__m128i *)(out + 16), o1); - out += 32; - } - } + // store + _mm_storeu_si128((__m128i *) (out + 0), o0); + _mm_storeu_si128((__m128i *) (out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); - int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); - int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); - int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); + int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); + int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); + int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - for (; i + 7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + for (; i+7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8 * 4; - } - } + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8*4; + } + } #endif - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1 << 19); // rounding - int r, g, b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr * stbi__float2fixed(1.40200f); - g = y_fixed + cr * -stbi__float2fixed(0.71414f) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb * stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned)r > 255) { - if (r < 0) - r = 0; - else - r = 255; - } - if ((unsigned)g > 255) { - if (g < 0) - g = 0; - else - g = 255; - } - if ((unsigned)b > 255) { - if (b < 0) - b = 0; - else - b = 255; - } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg * j) { - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) +{ + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg * j) { stbi__free_jpeg_components(j, j->s->img_n, 0); } +static void stbi__cleanup_jpeg(stbi__jpeg *j) +{ + stbi__free_jpeg_components(j, j->s->img_n, 0); +} -typedef struct { - resample_row_func resample; - stbi_uc *line0, *line1; - int hs, vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct +{ + resample_row_func resample; + stbi_uc *line0,*line1; + int hs,vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { - unsigned int t = x * y + 128; - return (stbi_uc)((t + (t >> 8)) >> 8); +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) +{ + unsigned int t = x*y + 128; + return (stbi_uc) ((t + (t >>8)) >> 8); } -static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * comp, int req_comp) { - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) +{ + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe - // validate req_comp - if (req_comp < 0 || req_comp > 4) - return stbi__errpuc("bad req_comp", "Internal error"); + // validate req_comp + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { - stbi__cleanup_jpeg(z); - return NULL; - } + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; - // nothing to do if no components requested; check this now to avoid - // accessing uninitialized coutput[0] later - if (decode_n <= 0) { - stbi__cleanup_jpeg(z); - return NULL; - } + // nothing to do if no components requested; check this now to avoid + // accessing uninitialized coutput[0] later + if (decode_n <= 0) { stbi__cleanup_jpeg(z); return NULL; } - // resample and color-convert - { - int k; - unsigned int i, j; - stbi_uc * output; - stbi_uc * coutput[4] = {NULL, NULL, NULL, NULL}; + // resample and color-convert + { + int k; + unsigned int i,j; + stbi_uc *output; + stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - stbi__resample res_comp[4]; + stbi__resample res_comp[4]; - for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { - stbi__cleanup_jpeg(z); - return stbi__errpuc("outofmem", "Out of memory"); + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs-1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; + else r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + // now go ahead and resample + for (j=0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, + y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) - r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) - r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) - r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) - r->resample = z->resample_row_hv_2_kernel; - else - r->resample = stbi__resample_row_generic; - } - - // can't error after this so, this is safe - output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { - stbi__cleanup_jpeg(z); - return stbi__errpuc("outofmem", "Out of memory"); - } - - // now go ahead and resample - for (j = 0; j < z->s->img_y; ++j) { - stbi_uc * out = output + n * z->s->img_x * j; - for (k = 0; k < decode_n; ++k) { - stbi__resample * r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; - } - } - if (n >= 3) { - stbi_uc * y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i = 0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i = 0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; + } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else + for (i=0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i=0; i < z->s->img_x; ++i) + *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i=0; i < z->s->img_x; ++i, out += 2) { + out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } } else { - if (is_rgb) { - if (n == 1) - for (i = 0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i = 0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i = 0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i = 0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc * y = coutput[0]; - if (n == 1) - for (i = 0; i < z->s->img_x; ++i) - out[i] = y[i]; - else - for (i = 0; i < z->s->img_x; ++i) { - *out++ = y[i]; - *out++ = 255; - } - } + stbi_uc *y = coutput[0]; + if (n == 1) + for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; + else + for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } } - } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) - *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } + } + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } } -static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - unsigned char * result; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); - if (!j) - return stbi__errpuc("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x, y, comp, req_comp); - STBI_FREE(j); - return result; +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + unsigned char* result; + stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__errpuc("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x,y,comp,req_comp); + STBI_FREE(j); + return result; } -static int stbi__jpeg_test(stbi__context * s) { - int r; - stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); - if (!j) - return stbi__err("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; +static int stbi__jpeg_test(stbi__context *s) +{ + int r; + stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; } -static int stbi__jpeg_info_raw(stbi__jpeg * j, int * x, int * y, int * comp) { - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind(j->s); - return 0; - } - if (x) - *x = j->s->img_x; - if (y) - *y = j->s->img_y; - if (comp) - *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) +{ + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind( j->s ); + return 0; + } + if (x) *x = j->s->img_x; + if (y) *y = j->s->img_y; + if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; } -static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { - int result; - stbi__jpeg * j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); - if (!j) - return stbi__err("outofmem", "Out of memory"); - memset(j, 0, sizeof(stbi__jpeg)); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) +{ + int result; + stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -4278,81 +4088,84 @@ static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) #define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct { - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[STBI__ZNSYMS]; - stbi__uint16 value[STBI__ZNSYMS]; +typedef struct +{ + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[STBI__ZNSYMS]; + stbi__uint16 value[STBI__ZNSYMS]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) { - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); - return n; +stbi_inline static int stbi__bitreverse16(int n) +{ + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); + return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) { - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16 - bits); +stbi_inline static int stbi__bit_reverse(int v, int bits) +{ + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16-bits); } -static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, int num) { - int i, k = 0; - int code, next_code[16], sizes[17]; +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) +{ + int i,k=0; + int code, next_code[16], sizes[17]; - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i = 0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i = 1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i = 1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16)code; - z->firstsymbol[i] = (stbi__uint16)k; - code = (code + sizes[i]); - if (sizes[i]) - if (code - 1 >= (1 << i)) - return stbi__err("bad codelengths", "Corrupt PNG"); - z->maxcode[i] = code << (16 - i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i = 0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); - z->size[c] = (stbi_uc)s; - z->value[c] = (stbi__uint16)i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s], s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i=0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i=1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i=1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16) code; + z->firstsymbol[i] = (stbi__uint16) k; + code = (code + sizes[i]); + if (sizes[i]) + if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); + z->maxcode[i] = code << (16-i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i=0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); + z->size [c] = (stbi_uc ) s; + z->value[c] = (stbi__uint16) i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s],s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); } - ++next_code[s]; - } - } - return 1; + } + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4361,298 +4174,297 @@ static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, in // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct { - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct +{ + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + int hit_zeof_once; + stbi__uint32 code_buffer; - char * zout; - char * zout_start; - char * zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf * z) { return (z->zbuffer >= z->zbuffer_end); } - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf * z) { return stbi__zeof(z) ? 0 : *z->zbuffer++; } - -static void stbi__fill_bits(stbi__zbuf * z) { - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); -} - -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf * z, int n) { - unsigned int k; - if (z->num_bits < n) - stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; -} - -static int stbi__zhuffman_decode_slowpath(stbi__zbuf * a, stbi__zhuffman * z) { - int b, s, k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s = STBI__ZFAST_BITS + 1;; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) - return -1; // invalid code! - // code size is s, so: - b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= STBI__ZNSYMS) - return -1; // some data was corrupt somewhere! - if (z->size[b] != s) - return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf * a, stbi__zhuffman * z) { - int b, s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf * z, char * zout, int n) // need to make room for n bytes +stbi_inline static int stbi__zeof(stbi__zbuf *z) { - char * q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) - return stbi__err("output buffer limit", "Corrupt PNG"); - cur = (unsigned int)(z->zout - z->zout_start); - limit = old_limit = (unsigned)(z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned)n) - return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if (limit > UINT_MAX / 2) - return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) - return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; + return (z->zbuffer >= z->zbuffer_end); } -static const int stbi__zlength_base[31] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, - 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; - -static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, - 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0}; - -static const int stbi__zdist_base[32] = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, - 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, - 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; - -static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, - 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; - -static int stbi__parse_huffman_block(stbi__zbuf * a) { - char * zout = a->zout; - for (;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) - return stbi__err("bad huffman code", "Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) - return 0; - zout = a->zout; - } - *zout++ = (char)z; - } else { - stbi_uc * p; - int len, dist; - if (z == 256) { - a->zout = zout; - return 1; - } - if (z >= 286) - return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) - len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0 || z >= 30) - return stbi__err("bad huffman code", - "Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) - dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) - return stbi__err("bad dist", "Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) - return 0; - zout = a->zout; - } - p = (stbi_uc *)(zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { - do - *zout++ = v; - while (--len); - } - } else { - if (len) { - do - *zout++ = *p++; - while (--len); - } - } - } - } +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) +{ + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__compute_huffman_codes(stbi__zbuf * a) { - static const stbi_uc length_dezigzag[19] = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op - stbi_uc codelength_sizes[19]; - int i, n; +static void stbi__fill_bits(stbi__zbuf *z) +{ + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} - int hlit = stbi__zreceive(a, 5) + 257; - int hdist = stbi__zreceive(a, 5) + 1; - int hclen = stbi__zreceive(a, 4) + 4; - int ntot = hlit + hdist; +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) +{ + unsigned int k; + if (z->num_bits < n) stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i = 0; i < hclen; ++i) { - int s = stbi__zreceive(a, 3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) - return 0; +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s,k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s=STBI__ZFAST_BITS+1; ; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) return -1; // invalid code! + // code size is s, so: + b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= STBI__ZNSYMS) return -1; // some data was corrupt somewhere! + if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + if (!a->hit_zeof_once) { + // This is the first time we hit eof, insert 16 extra padding btis + // to allow us to keep going; if we actually consume any of them + // though, that is invalid data. This is caught later. + a->hit_zeof_once = 1; + a->num_bits += 16; // add 16 implicit zero bits + } else { + // We already inserted our extra 16 padding bits and are again + // out, this stream is actually prematurely terminated. + return -1; + } + } else { + stbi__fill_bits(a); + } + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); + cur = (unsigned int) (z->zout - z->zout_start); + limit = old_limit = (unsigned) (z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; +} + +static const int stbi__zlength_base[31] = { + 3,4,5,6,7,8,9,10,11,13, + 15,17,19,23,27,31,35,43,51,59, + 67,83,99,115,131,163,195,227,258,0,0 }; + +static const int stbi__zlength_extra[31]= +{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; + +static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, +257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; + +static const int stbi__zdist_extra[32] = +{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) +{ + char *zout = a->zout; + for(;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) return 0; + zout = a->zout; + } + *zout++ = (char) z; + } else { + stbi_uc *p; + int len,dist; + if (z == 256) { + a->zout = zout; + if (a->hit_zeof_once && a->num_bits < 16) { + // The first time we hit zeof, we inserted 16 extra zero bits into our bit + // buffer so the decoder can just do its speculative decoding. But if we + // actually consumed any of those bits (which is the case when num_bits < 16), + // the stream actually read past the end so it is malformed. + return stbi__err("unexpected end","Corrupt PNG"); + } + return 1; + } + if (z >= 286) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0 || z >= 30) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); + if (len > a->zout_end - zout) { + if (!stbi__zexpand(a, zout, len)) return 0; + zout = a->zout; + } + p = (stbi_uc *) (zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { do *zout++ = v; while (--len); } + } else { + if (len) { do *zout++ = *p++; while (--len); } + } + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) +{ + static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286+32+137];//padding for maximum single op + stbi_uc codelength_sizes[19]; + int i,n; + + int hlit = stbi__zreceive(a,5) + 257; + int hdist = stbi__zreceive(a,5) + 1; + int hclen = stbi__zreceive(a,4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i=0; i < hclen; ++i) { + int s = stbi__zreceive(a,3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc) c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a,2)+3; + if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n-1]; + } else if (c == 17) { + c = stbi__zreceive(a,3)+3; + } else if (c == 18) { + c = stbi__zreceive(a,7)+11; + } else { return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc)c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a, 2) + 3; - if (n == 0) - return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n - 1]; - } else if (c == 17) { - c = stbi__zreceive(a, 3) + 3; - } else if (c == 18) { - c = stbi__zreceive(a, 7) + 11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) - return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes + n, fill, c); - n += c; - } - } - if (n != ntot) - return stbi__err("bad codelengths", "Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) - return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) - return 0; - return 1; + } + if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes+n, fill, c); + n += c; + } + } + if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; + return 1; } -static int stbi__parse_uncompressed_block(stbi__zbuf * a) { - stbi_uc header[4]; - int len, nlen, k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) - return stbi__err("zlib corrupt", "Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) - return stbi__err("zlib corrupt", "Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) - return stbi__err("read past buffer", "Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) - return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; +static int stbi__parse_uncompressed_block(stbi__zbuf *a) +{ + stbi_uc header[4]; + int len,nlen,k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; } -static int stbi__parse_zlib_header(stbi__zbuf * a) { - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec - if ((cmf * 256 + flg) % 31 != 0) - return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec - if (flg & 32) - return stbi__err("no preset dict", "Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) - return stbi__err("bad compression", "Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; +static int stbi__parse_zlib_header(stbi__zbuf *a) +{ + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; } -static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = { - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; -static const stbi_uc stbi__zdefault_distance[32] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; +static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = +{ + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 +}; +static const stbi_uc stbi__zdefault_distance[32] = +{ + 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 +}; /* Init algorithm: { @@ -4666,122 +4478,118 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf * a, int parse_header) { - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) - return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a, 1); - type = stbi__zreceive(a, 2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) - return 0; - } else if (type == 3) { - return 0; - } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, STBI__ZNSYMS)) - return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) - return 0; - } else { - if (!stbi__compute_huffman_codes(a)) - return 0; - } - if (!stbi__parse_huffman_block(a)) - return 0; - } - } while (!final); - return 1; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) +{ + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) return 0; + a->num_bits = 0; + a->code_buffer = 0; + a->hit_zeof_once = 0; + do { + final = stbi__zreceive(a,1); + type = stbi__zreceive(a,2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , STBI__ZNSYMS)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; + } else { + if (!stbi__compute_huffman_codes(a)) return 0; + } + if (!stbi__parse_huffman_block(a)) return 0; + } + } while (!final); + return 1; } -static int stbi__do_zlib(stbi__zbuf * a, char * obuf, int olen, int exp, int parse_header) { - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) +{ + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; - return stbi__parse_zlib(a, parse_header); + return stbi__parse_zlib(a, parse_header); } -STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } } -STBIDEF char * stbi_zlib_decode_malloc(char const * buffer, int len, int * outlen) { - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) +{ + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); } -STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, - int parse_header) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(initial_size); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } } -STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, char const * ibuffer, int ilen) { - stbi__zbuf a; - a.zbuffer = (stbi_uc *)ibuffer; - a.zbuffer_end = (stbi_uc *)ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int)(a.zout - a.zout_start); - else - return -1; +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int) (a.zout - a.zout_start); + else + return -1; } -STBIDEF char * stbi_zlib_decode_noheader_malloc(char const * buffer, int len, int * outlen) { - stbi__zbuf a; - char * p = (char *)stbi__malloc(16384); - if (p == NULL) - return NULL; - a.zbuffer = (stbi_uc *)buffer; - a.zbuffer_end = (stbi_uc *)buffer + len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) - *outlen = (int)(a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(16384); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer+len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } } -STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen) { - stbi__zbuf a; - a.zbuffer = (stbi_uc *)ibuffer; - a.zbuffer_end = (stbi_uc *)ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int)(a.zout - a.zout_start); - else - return -1; +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int) (a.zout - a.zout_start); + else + return -1; } #endif @@ -4796,1303 +4604,1131 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const cha // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct { - stbi__uint32 length; - stbi__uint32 type; +typedef struct +{ + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context * s) { - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) +{ + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context * s) { - static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; - int i; - for (i = 0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) - return stbi__err("bad png sig", "Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) +{ + static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; + int i; + for (i=0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); + return 1; } -typedef struct { - stbi__context * s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct +{ + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; + enum { - STBI__F_none = 0, - STBI__F_sub = 1, - STBI__F_up = 2, - STBI__F_avg = 3, - STBI__F_paeth = 4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none=0, + STBI__F_sub=1, + STBI__F_up=2, + STBI__F_avg=3, + STBI__F_paeth=4, + // synthetic filter used for first scanline to avoid needing a dummy row of 0s + STBI__F_avg_first }; -static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, STBI__F_avg_first, STBI__F_paeth_first}; +static stbi_uc first_row_filter[5] = +{ + STBI__F_none, + STBI__F_sub, + STBI__F_none, + STBI__F_avg_first, + STBI__F_sub // Paeth with b=c=0 turns out to be equivalent to sub +}; -static int stbi__paeth(int a, int b, int c) { - int p = a + b - c; - int pa = abs(p - a); - int pb = abs(p - b); - int pc = abs(p - c); - if (pa <= pb && pa <= pc) - return a; - if (pb <= pc) - return b; - return c; +static int stbi__paeth(int a, int b, int c) +{ + // This formulation looks very different from the reference in the PNG spec, but is + // actually equivalent and has favorable data dependencies and admits straightforward + // generation of branch-free code, which helps performance significantly. + int thresh = c*3 - (a + b); + int lo = a < b ? a : b; + int hi = a < b ? b : a; + int t0 = (hi <= thresh) ? lo : c; + int t1 = (thresh <= lo) ? hi : t0; + return t1; } -static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, 0, 0, 0, 0x01}; +static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; + +// adds an extra all-255 alpha channel +// dest == src is legal +// img_n must be 1 or 3 +static void stbi__create_png_alpha_expand8(stbi_uc *dest, stbi_uc *src, stbi__uint32 x, int img_n) +{ + int i; + // must process data backwards since we allow dest==src + if (img_n == 1) { + for (i=x-1; i >= 0; --i) { + dest[i*2+1] = 255; + dest[i*2+0] = src[i]; + } + } else { + STBI_ASSERT(img_n == 3); + for (i=x-1; i >= 0; --i) { + dest[i*4+3] = 255; + dest[i*4+2] = src[i*3+2]; + dest[i*4+1] = src[i*3+1]; + dest[i*4+0] = src[i*3+0]; + } + } +} // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, - stbi__uint32 y, int depth, int color) { - int bytes = (depth == 16 ? 2 : 1); - stbi__context * s = a->s; - stbi__uint32 i, j, stride = x * out_n * bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) +{ + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i,j,stride = x*out_n*bytes; + stbi__uint32 img_len, img_width_bytes; + stbi_uc *filter_buf; + int all_ok = 1; + int k; + int img_n = s->img_n; // copy it into a local for later - int output_bytes = out_n * bytes; - int filter_bytes = img_n * bytes; - int width = x; + int output_bytes = out_n*bytes; + int filter_bytes = img_n*bytes; + int width = x; - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); - a->out = (stbi_uc *)stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) - return stbi__err("outofmem", "Out of memory"); + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); + a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) return stbi__err("outofmem", "Out of memory"); - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) - return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; + // note: error exits here don't need to clean up a->out individually, + // stbi__do_png always does on error. + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + if (!stbi__mad2sizes_valid(img_width_bytes, y, img_width_bytes)) return stbi__err("too large", "Corrupt PNG"); + img_len = (img_width_bytes + 1) * y; - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) - return stbi__err("not enough pixels", "Corrupt PNG"); + // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, + // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), + // so just check for raw_len < img_len always. + if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * prior; - int filter = *raw++; + // Allocate two scan lines worth of filter workspace buffer. + filter_buf = (stbi_uc *) stbi__malloc_mad2(img_width_bytes, 2, 0); + if (!filter_buf) return stbi__err("outofmem", "Out of memory"); - if (filter > 4) - return stbi__err("invalid filter", "Corrupt PNG"); + // Filtering for low-bit-depth images + if (depth < 8) { + filter_bytes = 1; + width = img_width_bytes; + } - if (depth < 8) { - if (img_width_bytes > x) - return stbi__err("invalid width", "Corrupt PNG"); - cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + for (j=0; j < y; ++j) { + // cur/prior filter buffers alternate + stbi_uc *cur = filter_buf + (j & 1)*img_width_bytes; + stbi_uc *prior = filter_buf + (~j & 1)*img_width_bytes; + stbi_uc *dest = a->out + stride*j; + int nk = width * filter_bytes; + int filter = *raw++; - // if first row, use special filter that doesn't sample previous row - if (j == 0) - filter = first_row_filter[filter]; + // check filter type + if (filter > 4) { + all_ok = stbi__err("invalid filter","Corrupt PNG"); + break; + } - // handle first byte explicitly - for (k = 0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none: - cur[k] = raw[k]; - break; - case STBI__F_sub: - cur[k] = raw[k]; - break; - case STBI__F_up: - cur[k] = STBI__BYTECAST(raw[k] + prior[k]); - break; - case STBI__F_avg: - cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); - break; - case STBI__F_paeth: - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); - break; - case STBI__F_avg_first: - cur[k] = raw[k]; - break; - case STBI__F_paeth_first: - cur[k] = raw[k]; - break; + // if first row, use special filter that doesn't sample previous row + if (j == 0) filter = first_row_filter[filter]; + + // perform actual filtering + switch (filter) { + case STBI__F_none: + memcpy(cur, raw, nk); + break; + case STBI__F_sub: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); + break; + case STBI__F_up: + for (k = 0; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); + break; + case STBI__F_paeth: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); // prior[k] == stbi__paeth(0,prior[k],0) + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes], prior[k], prior[k-filter_bytes])); + break; + case STBI__F_avg_first: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); + break; + } + + raw += nk; + + // expand decoded bits in cur to dest, also adding an extra alpha channel if desired + if (depth < 8) { + stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + stbi_uc *in = cur; + stbi_uc *out = dest; + stbi_uc inb = 0; + stbi__uint32 nsmp = x*img_n; + + // expand bits to bytes first + if (depth == 4) { + for (i=0; i < nsmp; ++i) { + if ((i & 1) == 0) inb = *in++; + *out++ = scale * (inb >> 4); + inb <<= 4; } - } - - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes + 1] = 255; // first pixel bottom byte + } else if (depth == 2) { + for (i=0; i < nsmp; ++i) { + if ((i & 3) == 0) inb = *in++; + *out++ = scale * (inb >> 6); + inb <<= 2; } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; - } - - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1) * filter_bytes; -#define STBI__CASE(f) \ - case f: \ - for (k = 0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: - memcpy(cur, raw, nk); - break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); } - break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } - break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } - break; - STBI__CASE(STBI__F_paeth) { - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); - } - break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); } - break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); } - break; + } else { + STBI_ASSERT(depth == 1); + for (i=0; i < nsmp; ++i) { + if ((i & 7) == 0) inb = *in++; + *out++ = scale * (inb >> 7); + inb <<= 1; } -#undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n + 1 == out_n); -#define STBI__CASE(f) \ - case f: \ - for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, cur += output_bytes, prior += output_bytes) \ - for (k = 0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } - break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); } - break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } - break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } - break; - STBI__CASE(STBI__F_paeth) { - cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); - } - break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); } - break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); } - break; + } + + // insert alpha=255 values if desired + if (img_n != out_n) + stbi__create_png_alpha_expand8(dest, dest, x, img_n); + } else if (depth == 8) { + if (img_n == out_n) + memcpy(dest, cur, x*img_n); + else + stbi__create_png_alpha_expand8(dest, cur, x, img_n); + } else if (depth == 16) { + // convert the image data from big-endian to platform-native + stbi__uint16 *dest16 = (stbi__uint16*)dest; + stbi__uint32 nsmp = x*img_n; + + if (img_n == out_n) { + for (i = 0; i < nsmp; ++i, ++dest16, cur += 2) + *dest16 = (cur[0] << 8) | cur[1]; + } else { + STBI_ASSERT(img_n+1 == out_n); + if (img_n == 1) { + for (i = 0; i < x; ++i, dest16 += 2, cur += 2) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = 0xffff; + } + } else { + STBI_ASSERT(img_n == 3); + for (i = 0; i < x; ++i, dest16 += 4, cur += 6) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = (cur[2] << 8) | cur[3]; + dest16[2] = (cur[4] << 8) | cur[5]; + dest16[3] = 0xffff; + } } -#undef STBI__CASE + } + } + } - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride * j; // start at the beginning of the row again - for (i = 0; i < x; ++i, cur += output_bytes) { - cur[filter_bytes + 1] = 255; - } - } - } - } + STBI_FREE(filter_buf); + if (!all_ok) return 0; - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j = 0; j < y; ++j) { - stbi_uc * cur = a->out + stride * j; - stbi_uc * in = a->out + stride * j + x * out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for - // 1/2/4-bit png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that - // will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k = x * img_n; k >= 2; k -= 2, ++in) { - *cur++ = scale * ((*in >> 4)); - *cur++ = scale * ((*in) & 0x0f); - } - if (k > 0) - *cur++ = scale * ((*in >> 4)); - } else if (depth == 2) { - for (k = x * img_n; k >= 4; k -= 4, ++in) { - *cur++ = scale * ((*in >> 6)); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in) & 0x03); - } - if (k > 0) - *cur++ = scale * ((*in >> 6)); - if (k > 1) - *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) - *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k = x * img_n; k >= 8; k -= 8, ++in) { - *cur++ = scale * ((*in >> 7)); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in) & 0x01); - } - if (k > 0) - *cur++ = scale * ((*in >> 7)); - if (k > 1) - *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) - *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) - *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) - *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) - *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) - *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride * j; - if (img_n == 1) { - for (q = x - 1; q >= 0; --q) { - cur[q * 2 + 1] = 255; - cur[q * 2 + 0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q = x - 1; q >= 0; --q) { - cur[q * 4 + 3] = 255; - cur[q * 4 + 2] = cur[q * 3 + 2]; - cur[q * 4 + 1] = cur[q * 3 + 1]; - cur[q * 4 + 0] = cur[q * 3 + 0]; - } - } - } - } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc * cur = a->out; - stbi__uint16 * cur16 = (stbi__uint16 *)cur; - - for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { - *cur16 = (cur[0] << 8) | cur[1]; - } - } - - return 1; + return 1; } -static int stbi__create_png_image(stbi__png * a, stbi_uc * image_data, stbi__uint32 image_data_len, int out_n, int depth, - int color, int interlaced) { - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc * final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) +{ + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - // de-interlacing - final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - if (!final) - return stbi__err("outofmem", "Out of memory"); - for (p = 0; p < 7; ++p) { - int xorig[] = {0, 4, 0, 2, 0, 1, 0}; - int yorig[] = {0, 0, 4, 0, 2, 0, 1}; - int xspc[] = {8, 8, 4, 4, 2, 2, 1}; - int yspc[] = {8, 8, 8, 4, 4, 2, 2}; - int i, j, x, y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; + // de-interlacing + final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + if (!final) return stbi__err("outofmem", "Out of memory"); + for (p=0; p < 7; ++p) { + int xorig[] = { 0,4,0,2,0,1,0 }; + int yorig[] = { 0,0,4,0,2,0,1 }; + int xspc[] = { 8,8,4,4,2,2,1 }; + int yspc[] = { 8,8,8,4,4,2,2 }; + int i,j,x,y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j=0; j < y; ++j) { + for (i=0; i < x; ++i) { + int out_y = j*yspc[p]+yorig[p]; + int out_x = i*xspc[p]+xorig[p]; + memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, + a->out + (j*x+i)*out_bytes, out_bytes); } - for (j = 0; j < y; ++j) { - for (i = 0; i < x; ++i) { - int out_y = j * yspc[p] + yorig[p]; - int out_x = i * xspc[p] + xorig[p]; - memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, a->out + (j * x + i) * out_bytes, - out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__compute_transparency(stbi__png * z, stbi_uc tc[3], int out_n) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i=0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i=0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png * z, stbi__uint16 tc[3], int out_n) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 * p = (stbi__uint16 *)z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16*) z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png * a, stbi_uc * palette, int len, int pal_img_n) { - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) +{ + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; - p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) - return stbi__err("outofmem", "Out of memory"); + p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) return stbi__err("outofmem", "Out of memory"); - // between here and free(out) below, exitting would leak - temp_out = p; + // between here and free(out) below, exitting would leak + temp_out = p; - if (pal_img_n == 3) { - for (i = 0; i < pixel_count; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p += 3; - } - } else { - for (i = 0; i < pixel_count; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p[3] = palette[n + 3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; + if (pal_img_n == 3) { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p += 3; + } + } else { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p[3] = palette[n+3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; - STBI_NOTUSED(len); + STBI_NOTUSED(len); - return 1; + return 1; } static int stbi__unpremultiply_on_load_global = 0; static int stbi__de_iphone_flag_global = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { - stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { - stbi__de_iphone_flag_global = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_global = flag_true_if_should_convert; } #ifndef STBI_THREAD_LOCAL -#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global -#define stbi__de_iphone_flag stbi__de_iphone_flag_global +#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global +#define stbi__de_iphone_flag stbi__de_iphone_flag_global #else static STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set; static STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set; -STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) { - stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; - stbi__unpremultiply_on_load_set = 1; +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; + stbi__unpremultiply_on_load_set = 1; } -STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) { - stbi__de_iphone_flag_local = flag_true_if_should_convert; - stbi__de_iphone_flag_set = 1; +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_local = flag_true_if_should_convert; + stbi__de_iphone_flag_set = 1; } -#define stbi__unpremultiply_on_load \ - (stbi__unpremultiply_on_load_set ? stbi__unpremultiply_on_load_local : stbi__unpremultiply_on_load_global) -#define stbi__de_iphone_flag (stbi__de_iphone_flag_set ? stbi__de_iphone_flag_local : stbi__de_iphone_flag_global) +#define stbi__unpremultiply_on_load (stbi__unpremultiply_on_load_set \ + ? stbi__unpremultiply_on_load_local \ + : stbi__unpremultiply_on_load_global) +#define stbi__de_iphone_flag (stbi__de_iphone_flag_set \ + ? stbi__de_iphone_flag_local \ + : stbi__de_iphone_flag_global) #endif // STBI_THREAD_LOCAL -static void stbi__de_iphone(stbi__png * z) { - stbi__context * s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc * p = z->out; +static void stbi__de_iphone(stbi__png *z) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - if (s->img_out_n == 3) { // convert bgr to rgb - for (i = 0; i < pixel_count; ++i) { + if (s->img_out_n == 3) { // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i=0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = ( t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; + } + } else { + // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { stbi_uc t = p[0]; p[0] = p[2]; p[2] = t; - p += 3; - } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i = 0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = (t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } - } else { - // convert bgr to rgb - for (i = 0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } - } - } + p += 4; + } + } + } } -#define STBI__PNG_TYPE(a, b, c, d) (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + (unsigned)(d)) +#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) -static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { - stbi_uc palette[1024], pal_img_n = 0; - stbi_uc has_trans = 0, tc[3] = {0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; - int first = 1, k, interlace = 0, color = 0, is_iphone = 0; - stbi__context * s = z->s; +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) +{ + stbi_uc palette[1024], pal_img_n=0; + stbi_uc has_trans=0, tc[3]={0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; + int first=1,k,interlace=0, color=0, is_iphone=0; + stbi__context *s = z->s; - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; - if (!stbi__check_png_header(s)) - return 0; + if (!stbi__check_png_header(s)) return 0; - if (scan == STBI__SCAN_type) - return 1; + if (scan == STBI__SCAN_type) return 1; - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C','g','B','I'): is_iphone = 1; stbi__skip(s, c.length); break; - case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { - int comp, filter; - if (!first) - return stbi__err("multiple IHDR", "Corrupt PNG"); + case STBI__PNG_TYPE('I','H','D','R'): { + int comp,filter; + if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); first = 0; - if (c.length != 13) - return stbi__err("bad IHDR len", "Corrupt PNG"); + if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); s->img_x = stbi__get32be(s); s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - z->depth = stbi__get8(s); - if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) - return stbi__err("1/2/4/8/16-bit only", "PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); - if (color > 6) - return stbi__err("bad ctype", "Corrupt PNG"); - if (color == 3 && z->depth == 16) - return stbi__err("bad ctype", "Corrupt PNG"); - if (color == 3) - pal_img_n = 3; - else if (color & 1) - return stbi__err("bad ctype", "Corrupt PNG"); - comp = stbi__get8(s); - if (comp) - return stbi__err("bad comp method", "Corrupt PNG"); - filter = stbi__get8(s); - if (filter) - return stbi__err("bad filter method", "Corrupt PNG"); - interlace = stbi__get8(s); - if (interlace > 1) - return stbi__err("bad interlace method", "Corrupt PNG"); - if (!s->img_x || !s->img_y) - return stbi__err("0-pixel image", "Corrupt PNG"); + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); + comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); + filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); + interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); + if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) - return stbi__err("too large", "Image too large to decode"); + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) - return stbi__err("too large", "Corrupt PNG"); + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); } // even with SCAN_header, have to scan to see if we have a tRNS break; - } + } - case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256 * 3) - return stbi__err("invalid PLTE", "Corrupt PNG"); + case STBI__PNG_TYPE('P','L','T','E'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); pal_len = c.length / 3; - if (pal_len * 3 != c.length) - return stbi__err("invalid PLTE", "Corrupt PNG"); - for (i = 0; i < pal_len; ++i) { - palette[i * 4 + 0] = stbi__get8(s); - palette[i * 4 + 1] = stbi__get8(s); - palette[i * 4 + 2] = stbi__get8(s); - palette[i * 4 + 3] = 255; + if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); + for (i=0; i < pal_len; ++i) { + palette[i*4+0] = stbi__get8(s); + palette[i*4+1] = stbi__get8(s); + palette[i*4+2] = stbi__get8(s); + palette[i*4+3] = 255; } break; - } + } - case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) - return stbi__err("tRNS after IDAT", "Corrupt PNG"); + case STBI__PNG_TYPE('t','R','N','S'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); if (pal_img_n) { - if (scan == STBI__SCAN_header) { - s->img_n = 4; - return 1; - } - if (pal_len == 0) - return stbi__err("tRNS before PLTE", "Corrupt PNG"); - if (c.length > pal_len) - return stbi__err("bad tRNS len", "Corrupt PNG"); - pal_img_n = 4; - for (i = 0; i < c.length; ++i) - palette[i * 4 + 3] = stbi__get8(s); + if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } + if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); + if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); + pal_img_n = 4; + for (i=0; i < c.length; ++i) + palette[i*4+3] = stbi__get8(s); } else { - if (!(s->img_n & 1)) - return stbi__err("tRNS with alpha", "Corrupt PNG"); - if (c.length != (stbi__uint32)s->img_n * 2) - return stbi__err("bad tRNS len", "Corrupt PNG"); - has_trans = 1; - // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. - if (scan == STBI__SCAN_header) { - ++s->img_n; - return 1; - } - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) - tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) - tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * - stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } + if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); + if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); + has_trans = 1; + // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. + if (scan == STBI__SCAN_header) { ++s->img_n; return 1; } + if (z->depth == 16) { + for (k = 0; k < s->img_n && k < 3; ++k) // extra loop test to suppress false GCC warning + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n && k < 3; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + } } break; - } + } - case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) - return stbi__err("no PLTE", "Corrupt PNG"); + case STBI__PNG_TYPE('I','D','A','T'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); if (scan == STBI__SCAN_header) { - // header scan definitely stops at first IDAT - if (pal_img_n) - s->img_n = pal_img_n; - return 1; + // header scan definitely stops at first IDAT + if (pal_img_n) + s->img_n = pal_img_n; + return 1; } - if (c.length > (1u << 30)) - return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); - if ((int)(ioff + c.length) < (int)ioff) - return 0; + if (c.length > (1u << 30)) return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); + if ((int)(ioff + c.length) < (int)ioff) return 0; if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc * p; - if (idata_limit == 0) - idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); - if (p == NULL) - return stbi__err("outofmem", "Out of memory"); - z->idata = p; + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); + z->idata = p; } - if (!stbi__getn(s, z->idata + ioff, c.length)) - return stbi__err("outofdata", "Corrupt PNG"); + if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); ioff += c.length; break; - } + } - case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + case STBI__PNG_TYPE('I','E','N','D'): { stbi__uint32 raw_len, bpl; - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) - return 1; - if (z->idata == NULL) - return stbi__err("no IDAT", "Corrupt PNG"); + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) return 1; + if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); // initial guess for decoded data size to avoid unnecessary reallocs bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag((char *)z->idata, ioff, raw_len, - (int *)&raw_len, !is_iphone); - if (z->expanded == NULL) - return 0; // zlib should set error - STBI_FREE(z->idata); - z->idata = NULL; - if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n + 1; + z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); + if (z->expanded == NULL) return 0; // zlib should set error + STBI_FREE(z->idata); z->idata = NULL; + if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) + s->img_out_n = s->img_n+1; else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) - return 0; + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) - return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) - return 0; - } + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; + } } if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); + stbi__de_iphone(z); if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) - s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; } - STBI_FREE(z->expanded); - z->expanded = NULL; + STBI_FREE(z->expanded); z->expanded = NULL; // end of PNG chunk, read and skip CRC stbi__get32be(s); return 1; - } + } - default: + default: // if critical, fail - if (first) - return stbi__err("first not IHDR", "Corrupt PNG"); + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); if ((c.type & (1 << 29)) == 0) { -#ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); -#endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); + #ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); + #endif + return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); } stbi__skip(s, c.length); break; - } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } } -static void * stbi__do_png(stbi__png * p, int * x, int * y, int * n, int req_comp, stbi__result_info * ri) { - void * result = NULL; - if (req_comp < 0 || req_comp > 4) - return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) - return result; - } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) - *n = p->s->img_n; - } - STBI_FREE(p->out); - p->out = NULL; - STBI_FREE(p->expanded); - p->expanded = NULL; - STBI_FREE(p->idata); - p->idata = NULL; +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) +{ + void *result=NULL; + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) *n = p->s->img_n; + } + STBI_FREE(p->out); p->out = NULL; + STBI_FREE(p->expanded); p->expanded = NULL; + STBI_FREE(p->idata); p->idata = NULL; - return result; + return result; } -static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi__png p; - p.s = s; - return stbi__do_png(&p, x, y, comp, req_comp, ri); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi__png p; + p.s = s; + return stbi__do_png(&p, x,y,comp,req_comp, ri); } -static int stbi__png_test(stbi__context * s) { - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; +static int stbi__png_test(stbi__context *s) +{ + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; } -static int stbi__png_info_raw(stbi__png * p, int * x, int * y, int * comp) { - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind(p->s); - return 0; - } - if (x) - *x = p->s->img_x; - if (y) - *y = p->s->img_y; - if (comp) - *comp = p->s->img_n; - return 1; +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) +{ + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind( p->s ); + return 0; + } + if (x) *x = p->s->img_x; + if (y) *y = p->s->img_y; + if (comp) *comp = p->s->img_n; + return 1; } -static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp) { - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); } -static int stbi__png_is16(stbi__context * s) { - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; +static int stbi__png_is16(stbi__context *s) +{ + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context * s) { - int r; - int sz; - if (stbi__get8(s) != 'B') - return 0; - if (stbi__get8(s) != 'M') - return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; +static int stbi__bmp_test_raw(stbi__context *s) +{ + int r; + int sz; + if (stbi__get8(s) != 'B') return 0; + if (stbi__get8(s) != 'M') return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; } -static int stbi__bmp_test(stbi__context * s) { - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test(stbi__context *s) +{ + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } + // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) { - int n = 0; - if (z == 0) - return -1; - if (z >= 0x10000) { - n += 16; - z >>= 16; - } - if (z >= 0x00100) { - n += 8; - z >>= 8; - } - if (z >= 0x00010) { - n += 4; - z >>= 4; - } - if (z >= 0x00004) { - n += 2; - z >>= 2; - } - if (z >= 0x00002) { - n += 1; /* >>= 1;*/ - } - return n; +static int stbi__high_bit(unsigned int z) +{ + int n=0; + if (z == 0) return -1; + if (z >= 0x10000) { n += 16; z >>= 16; } + if (z >= 0x00100) { n += 8; z >>= 8; } + if (z >= 0x00010) { n += 4; z >>= 4; } + if (z >= 0x00004) { n += 2; z >>= 2; } + if (z >= 0x00002) { n += 1;/* >>= 1;*/ } + return n; } -static int stbi__bitcount(unsigned int a) { - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) +{ + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) { - static unsigned int mul_table[9] = { - 0, - 0xff /*0b11111111*/, - 0x55 /*0b01010101*/, - 0x49 /*0b01001001*/, - 0x11 /*0b00010001*/, - 0x21 /*0b00100001*/, - 0x41 /*0b01000001*/, - 0x81 /*0b10000001*/, - 0x01 /*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0, 0, 1, 0, 2, 4, 6, 0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8 - bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +static int stbi__shiftsigned(unsigned int v, int shift, int bits) +{ + static unsigned int mul_table[9] = { + 0, + 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, + 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0,0,1,0,2,4,6,0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8-bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; } -typedef struct { - int bpp, offset, hsz; - unsigned int mr, mg, mb, ma, all_a; - int extra_read; +typedef struct +{ + int bpp, offset, hsz; + unsigned int mr,mg,mb,ma, all_a; + int extra_read; } stbi__bmp_data; -static int stbi__bmp_set_mask_defaults(stbi__bmp_data * info, int compress) { - // BI_BITFIELDS specifies masks explicitly, don't override - if (compress == 3) - return 1; +static int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress) +{ + // BI_BITFIELDS specifies masks explicitly, don't override + if (compress == 3) + return 1; - if (compress == 0) { - if (info->bpp == 16) { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } else if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - // otherwise, use defaults, which is all-0 - info->mr = info->mg = info->mb = info->ma = 0; - } - return 1; - } - return 0; // error + if (compress == 0) { + if (info->bpp == 16) { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } else if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + } else { + // otherwise, use defaults, which is all-0 + info->mr = info->mg = info->mb = info->ma = 0; + } + return 1; + } + return 0; // error } -static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') - return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) +{ + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; - if (info->offset < 0) - return stbi__errpuc("bad BMP", "bad BMP"); + if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) - return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) - return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) - return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - if (compress >= 4) - return stbi__errpuc("BMP JPEG/PNG", - "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes - if (compress == 3 && info->bpp != 16 && info->bpp != 32) - return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - stbi__bmp_set_mask_defaults(info, compress); - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - // V4/V5 header - int i; - if (hsz != 108 && hsz != 124) - return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs - stbi__bmp_set_mask_defaults(info, compress); - stbi__get32le(s); // discard color space - for (i = 0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } - } - } - return (void *)1; + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + if (compress >= 4) return stbi__errpuc("BMP JPEG/PNG", "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + if (compress == 3 && info->bpp != 16 && info->bpp != 32) return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + stbi__bmp_set_mask_defaults(info, compress); + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else { + // V4/V5 header + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + stbi__bmp_set_mask_defaults(info, compress); + stbi__get32le(s); // discard color space + for (i=0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved + } + } + } + return (void *) 1; } -static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; - unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; - stbi_uc pal[256][4]; - int psize = 0, i, j, width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + unsigned int mr=0,mg=0,mb=0,ma=0, all_a; + stbi_uc pal[256][4]; + int psize=0,i,j,width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); - flip_vertically = ((int)s->img_y) > 0; - s->img_y = abs((int)s->img_y); + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + flip_vertically = ((int) s->img_y) > 0; + s->img_y = abs((int) s->img_y); - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - // accept some number of extra bytes after the header, but if the offset points either to before - // the header ends or implies a large amount of extra data, reject the file as malformed - int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); - int header_limit = 1024; // max we actually read is below 256 bytes currently. - int extra_data_limit = 256 * 4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. - if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { - return stbi__errpuc("bad header", "Corrupt BMP"); - } - // we established that bytes_read_so_far is positive and sensible. - // the first half of this test rejects offsets that are either too small positives, or - // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn - // ensures the number computed in the second half of the test can't overflow. - if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { - return stbi__errpuc("bad offset", "Corrupt BMP"); - } else { - stbi__skip(s, info.offset - bytes_read_so_far); - } - } + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + // accept some number of extra bytes after the header, but if the offset points either to before + // the header ends or implies a large amount of extra data, reject the file as malformed + int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256*4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { + return stbi__errpuc("bad header", "Corrupt BMP"); + } + // we established that bytes_read_so_far is positive and sensible. + // the first half of this test rejects offsets that are either too small positives, or + // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn + // ensures the number computed in the second half of the test can't overflow. + if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + stbi__skip(s, info.offset - bytes_read_so_far); + } + } - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert - out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z = 0; - if (psize == 0 || psize > 256) { - STBI_FREE(out); - return stbi__errpuc("invalid", "Corrupt BMP"); - } - for (i = 0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) - stbi__get8(s); - pal[i][3] = 255; - } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) - width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) - width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) - width = s->img_x; - else { - STBI_FREE(out); - return stbi__errpuc("bad bpp", "Corrupt BMP"); - } - pad = (-width) & 3; - if (info.bpp == 1) { - for (j = 0; j < (int)s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i = 0; i < (int)s->img_x; ++i) { - int color = (v >> bit_offset) & 0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) - out[z++] = 255; - if (i + 1 == (int)s->img_x) - break; - if ((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j = 0; j < (int)s->img_y; ++j) { - for (i = 0; i < (int)s->img_x; i += 2) { - int v = stbi__get8(s), v2 = 0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) - out[z++] = 255; - if (i + 1 == (int)s->img_x) - break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) - out[z++] = 255; - } - stbi__skip(s, pad); - } - } - } else { - int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, bcount = 0, acount = 0; - int z = 0; - int easy = 0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) - width = 3 * s->img_x; - else if (info.bpp == 16) - width = 2 * s->img_x; - else /* bpp = 32 and pad = 0 */ - width = 0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; - } - if (!easy) { - if (!mr || !mg || !mb) { - STBI_FREE(out); - return stbi__errpuc("bad masks", "Corrupt BMP"); - } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr) - 7; - rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg) - 7; - gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb) - 7; - bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma) - 7; - acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { - STBI_FREE(out); - return stbi__errpuc("bad masks", "Corrupt BMP"); - } - } - for (j = 0; j < (int)s->img_y; ++j) { - if (easy) { - for (i = 0; i < (int)s->img_x; ++i) { - unsigned char a; - out[z + 2] = stbi__get8(s); - out[z + 1] = stbi__get8(s); - out[z + 0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) - out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i = 0; i < (int)s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) - out[z++] = STBI__BYTECAST(a); - } + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z=0; + if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } + for (i=0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) width = s->img_x; + else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } + pad = (-width)&3; + if (info.bpp == 1) { + for (j=0; j < (int) s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i=0; i < (int) s->img_x; ++i) { + int color = (v>>bit_offset)&0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + if((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } } stbi__skip(s, pad); - } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j = 0; j < (int)s->img_y >> 1; ++j) { - stbi_uc * p1 = out + j * s->img_x * target; - stbi_uc * p2 = out + (s->img_y - 1 - j) * s->img_x * target; - for (i = 0; i < (int)s->img_x * target; ++i) { - t = p1[i]; - p1[i] = p2[i]; - p2[i] = t; + } + } else { + for (j=0; j < (int) s->img_y; ++j) { + for (i=0; i < (int) s->img_x; i += 2) { + int v=stbi__get8(s),v2=0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; } - } - } + stbi__skip(s, pad); + } + } + } else { + int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; + int z = 0; + int easy=0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) width = 3 * s->img_x; + else if (info.bpp == 16) width = 2*s->img_x; + else /* bpp = 32 and pad = 0 */ width=0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + } + for (j=0; j < (int) s->img_y; ++j) { + if (easy) { + for (i=0; i < (int) s->img_x; ++i) { + unsigned char a; + out[z+2] = stbi__get8(s); + out[z+1] = stbi__get8(s); + out[z+0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i=0; i < (int) s->img_x; ++i) { + stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) out[z++] = STBI__BYTECAST(a); + } + } + stbi__skip(s, pad); + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) + out[i] = 255; - *x = s->img_x; - *y = s->img_y; - if (comp) - *comp = s->img_n; - return out; + if (flip_vertically) { + stbi_uc t; + for (j=0; j < (int) s->img_y>>1; ++j) { + stbi_uc *p1 = out + j *s->img_x*target; + stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; + for (i=0; i < (int) s->img_x*target; ++i) { + t = p1[i]; p1[i] = p2[i]; p2[i] = t; + } + } + } + + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + return out; } #endif @@ -6100,74 +5736,68 @@ static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, in // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) - *is_rgb16 = 0; - switch (bits_per_pixel) { - case 8: - return STBI_grey; - case 16: - if (is_grey) - return STBI_grey_alpha; - // fallthrough - case 15: - if (is_rgb16) - *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: - return bits_per_pixel / 8; - default: - return 0; - } +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) +{ + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) *is_rgb16 = 0; + switch(bits_per_pixel) { + case 8: return STBI_grey; + case 16: if(is_grey) return STBI_grey_alpha; + // fallthrough + case 15: if(is_rgb16) *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: return bits_per_pixel/8; + default: return 0; + } } -static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) +{ int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; int sz, tga_colormap_type; - stbi__get8(s); // discard Offset + stbi__get8(s); // discard Offset tga_colormap_type = stbi__get8(s); // colormap type - if (tga_colormap_type > 1) { + if( tga_colormap_type > 1 ) { stbi__rewind(s); - return 0; // only RGB or indexed allowed + return 0; // only RGB or indexed allowed } tga_image_type = stbi__get8(s); // image type - if (tga_colormap_type == 1) { // colormapped (paletted) image + if ( tga_colormap_type == 1 ) { // colormapped (paletted) image if (tga_image_type != 1 && tga_image_type != 9) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { stbi__rewind(s); return 0; } - stbi__skip(s, 4); // skip image x and y origin + stbi__skip(s,4); // skip image x and y origin tga_colormap_bpp = sz; } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ((tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11)) { + if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { stbi__rewind(s); return 0; // only RGB or grey allowed, +/- RLE } - stbi__skip(s, 9); // skip colormap specification and image x/y origin + stbi__skip(s,9); // skip colormap specification and image x/y origin tga_colormap_bpp = 0; } tga_w = stbi__get16le(s); - if (tga_w < 1) { + if( tga_w < 1 ) { stbi__rewind(s); - return 0; // test width + return 0; // test width } tga_h = stbi__get16le(s); - if (tga_h < 1) { + if( tga_h < 1 ) { stbi__rewind(s); - return 0; // test height + return 0; // test height } tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits + stbi__get8(s); // ignore alpha bits if (tga_colormap_bpp != 0) { - if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { // when using a colormap, tga_bits_per_pixel is the size of the indexes // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); @@ -6177,268 +5807,270 @@ static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { } else { tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); } - if (!tga_comp) { - stbi__rewind(s); - return 0; + if(!tga_comp) { + stbi__rewind(s); + return 0; } - if (x) - *x = tga_w; - if (y) - *y = tga_h; - if (comp) - *comp = tga_comp; - return 1; // seems to have passed everything + if (x) *x = tga_w; + if (y) *y = tga_h; + if (comp) *comp = tga_comp; + return 1; // seems to have passed everything } -static int stbi__tga_test(stbi__context * s) { - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if (tga_color_type > 1) - goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if (tga_color_type == 1) { // colormapped (paletted) image - if (sz != 1 && sz != 9) - goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s, 4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) - goto errorEnd; - stbi__skip(s, 4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) - goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s, 9); // skip colormap specification and image x/y origin - } - if (stbi__get16le(s) < 1) - goto errorEnd; // test width - if (stbi__get16le(s) < 1) - goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) - goto errorEnd; // for colormapped images, bpp is size of an index - if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) - goto errorEnd; +static int stbi__tga_test(stbi__context *s) +{ + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if ( tga_color_type == 1 ) { // colormapped (paletted) image + if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + stbi__skip(s,4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s,9); // skip colormap specification and image x/y origin + } + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context * s, stbi_uc * out) { - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255) / 31); - out[1] = (stbi_uc)((g * 255) / 31); - out[2] = (stbi_uc)((b * 255) / 31); +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) +{ + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later + out[0] = (stbi_uc)((r * 255)/31); + out[1] = (stbi_uc)((g * 255)/31); + out[2] = (stbi_uc)((b * 255)/31); - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. } -static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16 = 0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char * tga_data; - unsigned char * tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16=0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO - if (tga_height > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - // do a tiny bit of precessing - if (tga_image_type >= 8) { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); + // do a tiny bit of precessing + if ( tga_image_type >= 8 ) + { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); - // If I'm paletted, then I'll use the number of bits from the palette - if (tga_indexed) - tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); + // If I'm paletted, then I'll use the number of bits from the palette + if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - // tga info - *x = tga_width; - *y = tga_height; - if (comp) - *comp = tga_comp; + // tga info + *x = tga_width; + *y = tga_height; + if (comp) *comp = tga_comp; - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); - tga_data = (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) - return stbi__errpuc("outofmem", "Out of memory"); + tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset); + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset ); - if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { - for (i = 0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height - i - 1 : i; - stbi_uc * tga_row = tga_data + row * tga_width * tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if (tga_indexed) { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); + if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { + for (i=0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height -i - 1 : i; + stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if ( tga_indexed) + { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start ); + // load the palette + tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); + } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i=0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start); - // load the palette - tga_palette = (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + } + // load the data + for (i=0; i < tga_width * tga_height; ++i) + { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if ( tga_is_RLE ) + { + if ( RLE_count == 0 ) + { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if ( !RLE_repeating ) + { + read_next_pixel = 1; } - if (tga_rgb16) { - stbi_uc * pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i = 0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - } - // load the data - for (i = 0; i < tga_width * tga_height; ++i) { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if (tga_is_RLE) { - if (RLE_count == 0) { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if (!RLE_repeating) { - read_next_pixel = 1; - } + } else + { + read_next_pixel = 1; + } + // OK, if I need to read a pixel, do it now + if ( read_next_pixel ) + { + // load however much data we did have + if ( tga_indexed ) + { + // read in index, then perform the lookup + int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if ( pal_idx >= tga_palette_len ) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx+j]; + } + } else if(tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); } else { - read_next_pixel = 1; + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } } - // OK, if I need to read a pixel, do it now - if (read_next_pixel) { - // load however much data we did have - if (tga_indexed) { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if (pal_idx >= tga_palette_len) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx + j]; - } - } else if (tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i * tga_comp + j] = raw_data[j]; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; - } - // do I need to invert the image? - if (tga_inverted) { - for (j = 0; j * 2 < tga_height; ++j) { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if ( tga_inverted ) + { + for (j = 0; j*2 < tga_height; ++j) + { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) + { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; } - } - // clear my palette, if I had one - if (tga_palette != NULL) { - STBI_FREE(tga_palette); - } - } + } + } + // clear my palette, if I had one + if ( tga_palette != NULL ) + { + STBI_FREE( tga_palette ); + } + } - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) { - unsigned char * tga_pixel = tga_data; - for (i = 0; i < tga_width * tga_height; ++i) { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; - } - } + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) + { + unsigned char* tga_pixel = tga_data; + for (i=0; i < tga_width * tga_height; ++i) + { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = + tga_x_origin = tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif @@ -6446,253 +6078,250 @@ static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, in // Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context * s) { - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; +static int stbi__psd_test(stbi__context *s) +{ + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; } -static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) { - int count, nleft, len; +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) +{ + int count, nleft, len; - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) - return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) - return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } - } - } + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; + } + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } - return 1; + return 1; } -static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w, h; - stbi_uc * out; - STBI_NOTUSED(ri); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w,h; + stbi_uc *out; + STBI_NOTUSED(ri); - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - // Skip 6 reserved bytes. - stbi__skip(s, 6); + // Skip 6 reserved bytes. + stbi__skip(s, 6 ); - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); - if (h > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s, stbi__get32be(s)); + // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) + stbi__skip(s,stbi__get32be(s) ); - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s)); + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s) ); - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s)); + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s) ); - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); - // Create the destination image. + // Create the destination image. - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *)stbi__malloc(4 * w * h); + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *) stbi__malloc(4 * w*h); - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w * h; + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w*h; - // Initialize the data to zero. - // memset( out, 0, pixelCount * 4 ); + // Initialize the data to zero. + //memset( out, 0, pixelCount * 4 ); - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. + // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. + // Else if n is 128, noop. + // Endloop - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2); + // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, + // which we're going to just skip. + stbi__skip(s, h * channelCount * 2 ); - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc * p; + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; - p = out + channel; - if (channel >= channelCount) { - // Fill this channel with default data. - for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); + p = out+channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } + + } else { + // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) + // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } + stbi_uc *p = out+channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; } - } - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc * p = out + channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16) stbi__get16be(s); } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 * q = ((stbi__uint16 *)out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16)stbi__get16be(s); - } else { - stbi_uc * p = out + channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc)(stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } + stbi_uc *p = out+channel; + if (bitdepth == 16) { // input bpc + for (i = 0; i < pixelCount; i++, p += 4) + *p = (stbi_uc) (stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } } - } - } + } + } + } - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i = 0; i < w * h; ++i) { - stbi__uint16 * pixel = (stbi__uint16 *)out + 4 * i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); - pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); - pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); - } + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i=0; i < w*h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); + pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); + pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); } - } else { - for (i = 0; i < w * h; ++i) { - unsigned char * pixel = out + 4 * i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); - pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); - pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); - } + } + } else { + for (i=0; i < w*h; ++i) { + unsigned char *pixel = out + 4*i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); + pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); + pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); } - } - } + } + } + } - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } - if (comp) - *comp = 4; - *y = h; - *x = w; + if (comp) *comp = 4; + *y = h; + *x = w; - return out; + return out; } #endif @@ -6704,221 +6333,216 @@ static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, in // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context * s, const char * str) { - int i; - for (i = 0; i < 4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s,const char *str) +{ + int i; + for (i=0; i<4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context * s) { - int i; +static int stbi__pic_test_core(stbi__context *s) +{ + int i; - if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) + return 0; - for (i = 0; i < 84; ++i) - stbi__get8(s); + for(i=0;i<84;++i) + stbi__get8(s); - if (!stbi__pic_is4(s, "PICT")) - return 0; + if (!stbi__pic_is4(s,"PICT")) + return 0; - return 1; + return 1; } -typedef struct { - stbi_uc size, type, channel; +typedef struct +{ + stbi_uc size,type,channel; } stbi__pic_packet; -static stbi_uc * stbi__readval(stbi__context * s, int channel, stbi_uc * dest) { - int mask = 0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) +{ + int mask=0x80, i; - for (i = 0; i < 4; ++i, mask >>= 1) { - if (channel & mask) { - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "PIC file too short"); - dest[i] = stbi__get8(s); - } - } + for (i=0; i<4; ++i, mask>>=1) { + if (channel & mask) { + if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); + dest[i]=stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel, stbi_uc * dest, const stbi_uc * src) { - int mask = 0x80, i; +static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) +{ + int mask=0x80,i; - for (i = 0; i < 4; ++i, mask >>= 1) - if (channel & mask) - dest[i] = src[i]; + for (i=0;i<4; ++i, mask>>=1) + if (channel&mask) + dest[i]=src[i]; } -static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, int * comp, stbi_uc * result) { - int act_comp = 0, num_packets = 0, y, chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) +{ + int act_comp=0,num_packets=0,y,chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data + // this will (should...) cater for even some bizarre stuff like having data // for the same channel in multiple packets. - do { - stbi__pic_packet * packet; + do { + stbi__pic_packet *packet; - if (num_packets == sizeof(packets) / sizeof(packets[0])) - return stbi__errpuc("bad format", "too many packets"); + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return stbi__errpuc("bad format","too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (reading packets)"); - if (packet->size != 8) - return stbi__errpuc("bad format", "packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); + if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for (y = 0; y < height; ++y) { - int packet_idx; + for(y=0; ytype) { + switch (packet->type) { default: - return stbi__errpuc("bad format", "packet has bad compression type"); + return stbi__errpuc("bad format","packet has bad compression type"); - case 0: { // uncompressed - int x; + case 0: {//uncompressed + int x; - for (x = 0; x < width; ++x, dest += 4) - if (!stbi__readval(s, packet->channel, dest)) - return 0; - break; + for(x=0;xchannel,dest)) + return 0; + break; } - case 1: // Pure RLE - { - int left = width, i; + case 1://Pure RLE + { + int left=width, i; - while (left > 0) { - stbi_uc count, value[4]; + while (left>0) { + stbi_uc count,value[4]; - count = stbi__get8(s); - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (pure read count)"); + count=stbi__get8(s); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - if (count > left) - count = (stbi_uc)left; + if (count > left) + count = (stbi_uc) left; - if (!stbi__readval(s, packet->channel, value)) + if (!stbi__readval(s,packet->channel,value)) return 0; + + for(i=0; ichannel,dest,value); + left -= count; + } + } + break; + + case 2: {//Mixed RLE + int left=width; + while (left>0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count==128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file","scanline overrun"); + + if (!stbi__readval(s,packet->channel,value)) return 0; - for (i = 0; i < count; ++i, dest += 4) - stbi__copyval(packet->channel, dest, value); - left -= count; - } - } break; + for(i=0;ichannel,dest,value); + } else { // Raw + ++count; + if (count>left) return stbi__errpuc("bad file","scanline overrun"); - case 2: { // Mixed RLE - int left = width; - while (left > 0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count == 128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file", "scanline overrun"); - - if (!stbi__readval(s, packet->channel, value)) - return 0; - - for (i = 0; i < count; ++i, dest += 4) - stbi__copyval(packet->channel, dest, value); - } else { // Raw - ++count; - if (count > left) - return stbi__errpuc("bad file", "scanline overrun"); - - for (i = 0; i < count; ++i, dest += 4) - if (!stbi__readval(s, packet->channel, dest)) - return 0; - } - left -= count; - } - break; + for(i=0;ichannel,dest)) + return 0; + } + left-=count; + } + break; } - } - } - } + } + } + } - return result; + return result; } -static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * result; - int i, x, y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) +{ + stbi_uc *result; + int i, x,y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) - comp = &internal_comp; + if (!comp) comp = &internal_comp; - for (i = 0; i < 92; ++i) - stbi__get8(s); + for (i=0; i<92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (stbi__at_eof(s)) - return stbi__errpuc("bad file", "file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) - return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); // skip `ratio' - stbi__get16be(s); // skip `fields' - stbi__get16be(s); // skip `pad' + stbi__get32be(s); //skip `ratio' + stbi__get16be(s); //skip `fields' + stbi__get16be(s); //skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); - if (!result) - return stbi__errpuc("outofmem", "Out of memory"); - memset(result, 0xff, x * y * 4); + // intermediate buffer is RGBA + result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); + if (!result) return stbi__errpuc("outofmem", "Out of memory"); + memset(result, 0xff, x*y*4); - if (!stbi__pic_load_core(s, x, y, comp, result)) { - STBI_FREE(result); - result = 0; - } - *px = x; - *py = y; - if (req_comp == 0) - req_comp = *comp; - result = stbi__convert_format(result, 4, req_comp, x, y); + if (!stbi__pic_load_core(s,x,y,comp, result)) { + STBI_FREE(result); + result=0; + } + *px = x; + *py = y; + if (req_comp == 0) req_comp = *comp; + result=stbi__convert_format(result,4,req_comp,x,y); - return result; + return result; } -static int stbi__pic_test(stbi__context * s) { - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) +{ + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6926,968 +6550,931 @@ static int stbi__pic_test(stbi__context * s) { // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct { - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct +{ + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct { - int w, h; - stbi_uc * out; // output buffer (always 4 components) - stbi_uc * background; // The current "background" as far as a gif is concerned - stbi_uc * history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc * color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct +{ + int w,h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context * s) { - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') - return 0; - if (stbi__get8(s) != 'a') - return 0; - return 1; +static int stbi__gif_test_raw(stbi__context *s) +{ + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') return 0; + if (stbi__get8(s) != 'a') return 0; + return 1; } -static int stbi__gif_test(stbi__context * s) { - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__gif_test(stbi__context *s) +{ + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; } -static void stbi__gif_parse_colortable(stbi__context * s, stbi_uc pal[256][4], int num_entries, int transp) { - int i; - for (i = 0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) +{ + int i; + for (i=0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } } -static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is_info) { - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) +{ + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); - version = stbi__get8(s); - if (version != '7' && version != '9') - return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') - return stbi__err("not GIF", "Corrupt GIF"); + version = stbi__get8(s); + if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; - if (g->w > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) - return stbi__err("too large", "Very large image (corrupt?)"); + if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (comp != 0) - *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments - if (is_info) - return 1; + if (is_info) return 1; - if (g->flags & 0x80) - stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + if (g->flags & 0x80) + stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); - return 1; + return 1; } -static int stbi__gif_info_raw(stbi__context * s, int * x, int * y, int * comp) { - stbi__gif * g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); - if (!g) - return stbi__err("outofmem", "Out of memory"); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind(s); - return 0; - } - if (x) - *x = g->w; - if (y) - *y = g->h; - STBI_FREE(g); - return 1; +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); + if (!g) return stbi__err("outofmem", "Out of memory"); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind( s ); + return 0; + } + if (x) *x = g->w; + if (y) *y = g->h; + STBI_FREE(g); + return 1; } -static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { - stbi_uc *p, *c; - int idx; +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) +{ + stbi_uc *p, *c; + int idx; - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); - if (g->cur_y >= g->max_y) - return; + if (g->cur_y >= g->max_y) return; - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; - } - } + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } } -static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw * p; +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) +{ + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) - return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc)init_code; - g->codes[init_code].suffix = (stbi_uc)init_code; - } + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc) init_code; + g->codes[init_code].suffix = (stbi_uc) init_code; + } - // support no starting clear code - avail = clear + 2; - oldcode = -1; + // support no starting clear code + avail = clear+2; + oldcode = -1; - len = 0; - for (;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; + len = 0; + for(;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32) stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s,len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); } - --len; - bits |= (stbi__int32)stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s, len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - p->prefix = (stbi__int16)oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + p->prefix = (stbi__int16) oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__out_gif_code(g, (stbi__uint16)code); + stbi__out_gif_code(g, (stbi__uint16) code); - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; - } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; } - } - } + + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } } // this function is designed to support animated gifs, although stb_image doesn't support it // two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * comp, int req_comp, stbi_uc * two_back) { - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) +{ + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp, 0)) - return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *)stbi__malloc(4 * pcount); - g->background = (stbi_uc *)stbi__malloc(4 * pcount); - g->history = (stbi_uc *)stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); + // on first frame, any non-written pixels get the background colour (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *) stbi__malloc(4 * pcount); + g->background = (stbi_uc *) stbi__malloc(4 * pcount); + g->history = (stbi_uc *) stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; + // image is treated as "transparent" at the start - ie, nothing overwrites the current background; + // background colour is only used for pixels that are not rendered first frame, after that "background" + // color refers to the color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background - } + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the old background + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); - } + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); - } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - // background is what out is after the undoing of the previou frame; - memcpy(g->background, g->out, 4 * g->w * g->h); - } + // background is what out is after the undoing of the previou frame; + memcpy( g->background, g->out, 4 * g->w * g->h ); + } - // clear my history; - memset(g->history, 0x00, g->w * g->h); // pixels that were affected previous frame + // clear my history; + memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { stbi__int32 x, y, w, h; - stbi_uc * o; + stbi_uc *o; x = stbi__get16le(s); y = stbi__get16le(s); w = stbi__get16le(s); h = stbi__get16le(s); if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); g->line_size = g->w * 4; g->start_x = x * 4; g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; // if the width of the specified rectangle is 0, that means // we may not see *any* pixels or the image is malformed; // to make sure this is caught, move the current y down to // max_y (which is what out_gif_code checks). if (w == 0) - g->cur_y = g->max_y; + g->cur_y = g->max_y; g->lflags = stbi__get8(s); if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; } else { - g->step = g->line_size; - g->parse = 0; + g->step = g->line_size; + g->parse = 0; } if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *)g->lpal; + stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *) g->lpal; } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *)g->pal; + g->color_table = (stbi_uc *) g->pal; } else - return stbi__errpuc("missing color table", "Corrupt GIF"); + return stbi__errpuc("missing color table", "Corrupt GIF"); o = stbi__process_gif_raster(s, g); - if (!o) - return NULL; + if (!o) return NULL; // if this was the first frame, pcount = g->w * g->h; if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = - 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); - } - } + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; + memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); + } + } } return o; - } + } - case 0x21: // Comment Extension. - { + case 0x21: // Comment Extension. + { int len; int ext = stbi__get8(s); if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; + } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); + stbi__skip(s, len); } break; - } + } - case 0x3B: // gif stream termination code - return (stbi_uc *)s; // using '1' causes warning on some compilers + case 0x3B: // gif stream termination code + return (stbi_uc *) s; // using '1' causes warning on some compilers - default: + default: return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } + } + } } -static void * stbi__load_gif_main_outofmem(stbi__gif * g, stbi_uc * out, int ** delays) { - STBI_FREE(g->out); - STBI_FREE(g->history); - STBI_FREE(g->background); +static void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays) +{ + STBI_FREE(g->out); + STBI_FREE(g->history); + STBI_FREE(g->background); - if (out) - STBI_FREE(out); - if (delays && *delays) - STBI_FREE(*delays); - return stbi__errpuc("outofmem", "Out of memory"); + if (out) STBI_FREE(out); + if (delays && *delays) STBI_FREE(*delays); + return stbi__errpuc("outofmem", "Out of memory"); } -static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp) { - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc * u = 0; - stbi_uc * out = 0; - stbi_uc * two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; - STBI_NOTUSED(out_size); - STBI_NOTUSED(delays_size); + STBI_NOTUSED(out_size); + STBI_NOTUSED(delays_size); - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; - if (out) { - void * tmp = (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); - if (!tmp) - return stbi__load_gif_main_outofmem(&g, out, delays); - else { - out = (stbi_uc *)tmp; - out_size = layers * stride; - } + if (out) { + void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); + if (!tmp) + return stbi__load_gif_main_outofmem(&g, out, delays); + else { + out = (stbi_uc*) tmp; + out_size = layers * stride; + } - if (delays) { - int * new_delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, sizeof(int) * layers); - if (!new_delays) - return stbi__load_gif_main_outofmem(&g, out, delays); - *delays = new_delays; - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc *)stbi__malloc(layers * stride); - if (!out) - return stbi__load_gif_main_outofmem(&g, out, delays); - out_size = layers * stride; - if (delays) { - *delays = (int *)stbi__malloc(layers * sizeof(int)); - if (!*delays) - return stbi__load_gif_main_outofmem(&g, out, delays); - delays_size = layers * sizeof(int); - } - } - memcpy(out + ((layers - 1) * stride), u, stride); - if (layers >= 2) { - two_back = out - 2 * stride; - } - - if (delays) { - (*delays)[layers - 1U] = g.delay; - } + if (delays) { + int *new_delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); + if (!new_delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + *delays = new_delays; + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc*)stbi__malloc( layers * stride ); + if (!out) + return stbi__load_gif_main_outofmem(&g, out, delays); + out_size = layers * stride; + if (delays) { + *delays = (int*) stbi__malloc( layers * sizeof(int) ); + if (!*delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + delays_size = layers * sizeof(int); + } + } + memcpy( out + ((layers - 1) * stride), u, stride ); + if (layers >= 2) { + two_back = out - 2 * stride; } - } while (u != 0); - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } } -static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *)s) - u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *) s) u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - return u; + return u; } -static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp) { return stbi__gif_info_raw(s, x, y, comp); } +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) +{ + return stbi__gif_info_raw(s,x,y,comp); +} #endif // ************************************************************************************************* // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context * s, const char * signature) { - int i; - for (i = 0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) +{ + int i; + for (i=0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; } -static int stbi__hdr_test(stbi__context * s) { - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if (!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; +static int stbi__hdr_test(stbi__context* s) +{ + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if(!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; } -#define STBI__HDR_BUFLEN 1024 -static char * stbi__hdr_gettoken(stbi__context * z, char * buffer) { - int len = 0; - char c = '\0'; +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) +{ + int len=0; + char c = '\0'; - c = (char)stbi__get8(z); + c = (char) stbi__get8(z); - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN - 1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; - } - c = (char)stbi__get8(z); - } + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN-1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char) stbi__get8(z); + } - buffer[len] = 0; - return buffer; + buffer[len] = 0; + return buffer; } -static void stbi__hdr_convert(float * output, stbi_uc * input, int req_comp) { - if (input[3] != 0) { - float f1; - // Exponent - f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; - } - if (req_comp == 2) - output[1] = 1; - if (req_comp == 4) - output[3] = 1; - } else { - switch (req_comp) { - case 4: - output[3] = 1; /* fallthrough */ - case 3: - output[0] = output[1] = output[2] = 0; - break; - case 2: - output[1] = 1; /* fallthrough */ - case 1: - output[0] = 0; - break; - } - } +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) +{ + if ( input[3] != 0 ) { + float f1; + // Exponent + f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) output[1] = 1; + if (req_comp == 4) output[3] = 1; + } else { + switch (req_comp) { + case 4: output[3] = 1; /* fallthrough */ + case 3: output[0] = output[1] = output[2] = 0; + break; + case 2: output[1] = 1; /* fallthrough */ + case 1: output[0] = 0; + break; + } + } } -static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - char buffer[STBI__HDR_BUFLEN]; - char * token; - int valid = 0; - int width, height; - stbi_uc * scanline; - float * hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1, c2, z; - const char * headerToken; - STBI_NOTUSED(ri); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1,c2, z; + const char *headerToken; + STBI_NOTUSED(ri); - // Check identifier - headerToken = stbi__hdr_gettoken(s, buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); + // Check identifier + headerToken = stbi__hdr_gettoken(s,buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); - // Parse header - for (;;) { - token = stbi__hdr_gettoken(s, buffer); - if (token[0] == 0) - break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) - valid = 1; - } + // Parse header + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } - if (!valid) - return stbi__errpf("unsupported format", "Unsupported HDR format"); + if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s, buffer); - if (strncmp(token, "-Y ", 3)) - return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int)strtol(token, &token, 10); - while (*token == ' ') - ++token; - if (strncmp(token, "+X ", 3)) - return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int)strtol(token, NULL, 10); + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int) strtol(token, NULL, 10); - if (height > STBI_MAX_DIMENSIONS) - return stbi__errpf("too large", "Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) - return stbi__errpf("too large", "Very large image (corrupt?)"); + if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - *x = width; - *y = height; + *x = width; + *y = height; - if (comp) - *comp = 3; - if (req_comp == 0) - req_comp = 3; + if (comp) *comp = 3; + if (req_comp == 0) req_comp = 3; - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); - // Read data - hdr_data = (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); + // Read data + hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); - // Load image data - // image data is stored as some number of sca - if (width < 8 || width >= 32768) { - // Read flat data - for (j = 0; j < height; ++j) { - for (i = 0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } - } - } else { - // Read RLE-encoded data - scanline = NULL; + // Load image data + // image data is stored as some number of sca + if ( width < 8 || width >= 32768) { + // Read flat data + for (j=0; j < height; ++j) { + for (i=0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); + } + } + } else { + // Read RLE-encoded data + scanline = NULL; - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc)c1; - rgbe[1] = (stbi_uc)c2; - rgbe[2] = (stbi_uc)len; - rgbe[3] = (stbi_uc)stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); - } - if (scanline == NULL) { - scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); - } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if ((count == 0) || (count > nleft)) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("corrupt", "bad RLE data in HDR"); - } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if ((count == 0) || (count > nleft)) { - STBI_FREE(hdr_data); - STBI_FREE(scanline); - return stbi__errpf("corrupt", "bad RLE data in HDR"); - } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } - } - } - for (i = 0; i < width; ++i) - stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, scanline + i * 4, req_comp); - } - if (scanline) + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a decoded + // pixel (note this can't be a valid pixel--one of RGB must be >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc) c1; + rgbe[1] = (stbi_uc) c2; + rgbe[2] = (stbi_uc) len; + rgbe[3] = (stbi_uc) stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; STBI_FREE(scanline); - } + goto main_decode_loop; // yes, this makes no sense + } + len <<= 8; + len |= stbi__get8(s); + if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } + if (scanline == NULL) { + scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } + } - return hdr_data; + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } + } + for (i=0; i < width; ++i) + stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; } -static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp) { - char buffer[STBI__HDR_BUFLEN]; - char * token; - int valid = 0; - int dummy; +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; - if (stbi__hdr_test(s) == 0) { - stbi__rewind(s); - return 0; - } + if (stbi__hdr_test(s) == 0) { + stbi__rewind( s ); + return 0; + } - for (;;) { - token = stbi__hdr_gettoken(s, buffer); - if (token[0] == 0) - break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) - valid = 1; - } + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } - if (!valid) { - stbi__rewind(s); - return 0; - } - token = stbi__hdr_gettoken(s, buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind(s); - return 0; - } - token += 3; - *y = (int)strtol(token, &token, 10); - while (*token == ' ') - ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind(s); - return 0; - } - token += 3; - *x = (int)strtol(token, NULL, 10); - *comp = 3; - return 1; + if (!valid) { + stbi__rewind( s ); + return 0; + } + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *y = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *x = (int) strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp) { - void * p; - stbi__bmp_data info; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) +{ + void *p; + stbi__bmp_data info; - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - if (p == NULL) { - stbi__rewind(s); - return 0; - } - if (x) - *x = s->img_x; - if (y) - *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + if (p == NULL) { + stbi__rewind( s ); + return 0; + } + if (x) *x = s->img_x; + if (y) *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp) { - int channelCount, dummy, depth; - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind(s); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind(s); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind(s); - return 0; - } - *comp = 4; - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) +{ + int channelCount, dummy, depth; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind( s ); + return 0; + } + *comp = 4; + return 1; } -static int stbi__psd_is16(stbi__context * s) { - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind(s); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind(s); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind(s); - return 0; - } - STBI_NOTUSED(stbi__get32be(s)); - STBI_NOTUSED(stbi__get32be(s)); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind(s); - return 0; - } - return 1; +static int stbi__psd_is16(stbi__context *s) +{ + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + STBI_NOTUSED(stbi__get32be(s)); + STBI_NOTUSED(stbi__get32be(s)); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind( s ); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { - int act_comp = 0, num_packets = 0, chained, dummy; - stbi__pic_packet packets[10]; +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) +{ + int act_comp=0,num_packets=0,chained,dummy; + stbi__pic_packet packets[10]; - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; - if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { - stbi__rewind(s); - return 0; - } + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } - stbi__skip(s, 88); + stbi__skip(s, 88); - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind(s); - return 0; - } - if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind(s); - return 0; - } + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind( s); + return 0; + } + if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind( s ); + return 0; + } - stbi__skip(s, 8); + stbi__skip(s, 8); - do { - stbi__pic_packet * packet; + do { + stbi__pic_packet *packet; - if (num_packets == sizeof(packets) / sizeof(packets[0])) - return 0; + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return 0; - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - if (stbi__at_eof(s)) { - stbi__rewind(s); - return 0; - } - if (packet->size != 8) { - stbi__rewind(s); - return 0; - } - } while (chained); + if (stbi__at_eof(s)) { + stbi__rewind( s ); + return 0; + } + if (packet->size != 8) { + stbi__rewind( s ); + return 0; + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7904,271 +7491,272 @@ static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context * s) { - char p, t; - p = (char)stbi__get8(s); - t = (char)stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) +{ + char p, t; + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind( s ); + return 0; + } + return 1; } -static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { - stbi_uc * out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + STBI_NOTUSED(ri); - ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); - if (ri->bits_per_channel == 0) - return 0; + ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); + if (ri->bits_per_channel == 0) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) - return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) - *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; - if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *)stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); - if (!out) - return stbi__errpuc("outofmem", "Out of memory"); - if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { - STBI_FREE(out); - return stbi__errpuc("bad PNM", "PNM file truncated"); - } + out = (stbi_uc *) stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { + STBI_FREE(out); + return stbi__errpuc("bad PNM", "PNM file truncated"); + } - if (req_comp && req_comp != s->img_n) { - if (ri->bits_per_channel == 16) { - out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, s->img_n, req_comp, s->img_x, s->img_y); - } else { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - } - if (out == NULL) - return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + if (ri->bits_per_channel == 16) { + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, s->img_n, req_comp, s->img_x, s->img_y); + } else { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + } + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } - -static void stbi__pnm_skip_whitespace(stbi__context * s, char * c) { - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char)stbi__get8(s); - - if (stbi__at_eof(s) || *c != '#') - break; - - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') - *c = (char)stbi__get8(s); - } +static int stbi__pnm_isspace(char c) +{ + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } -static int stbi__pnm_isdigit(char c) { return c >= '0' && c <= '9'; } +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) +{ + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char) stbi__get8(s); -static int stbi__pnm_getinteger(stbi__context * s, char * c) { - int value = 0; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value * 10 + (*c - '0'); - *c = (char)stbi__get8(s); - if ((value > 214748364) || (value == 214748364 && *c > '7')) - return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); - } - - return value; + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) + *c = (char) stbi__get8(s); + } } -static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { - int maxv, dummy; - char c, p, t; - - if (!x) - x = &dummy; - if (!y) - y = &dummy; - if (!comp) - comp = &dummy; - - stbi__rewind(s); - - // Get identifier - p = (char)stbi__get8(s); - t = (char)stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } - - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - - c = (char)stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); - - *x = stbi__pnm_getinteger(s, &c); // read width - if (*x == 0) - return stbi__err("invalid width", "PPM image header had zero or overflowing width"); - stbi__pnm_skip_whitespace(s, &c); - - *y = stbi__pnm_getinteger(s, &c); // read height - if (*y == 0) - return stbi__err("invalid width", "PPM image header had zero or overflowing width"); - stbi__pnm_skip_whitespace(s, &c); - - maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 65535) - return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); - else if (maxv > 255) - return 16; - else - return 8; +static int stbi__pnm_isdigit(char c) +{ + return c >= '0' && c <= '9'; } -static int stbi__pnm_is16(stbi__context * s) { - if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) - return 1; - return 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) +{ + int value = 0; + + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value*10 + (*c - '0'); + *c = (char) stbi__get8(s); + if((value > 214748364) || (value == 214748364 && *c > '7')) + return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); + } + + return value; +} + +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) +{ + int maxv, dummy; + char c, p, t; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + stbi__rewind(s); + + // Get identifier + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + + *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + + c = (char) stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); + + *x = stbi__pnm_getinteger(s, &c); // read width + if(*x == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + *y = stbi__pnm_getinteger(s, &c); // read height + if (*y == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + maxv = stbi__pnm_getinteger(s, &c); // read max value + if (maxv > 65535) + return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); + else if (maxv > 255) + return 16; + else + return 8; +} + +static int stbi__pnm_is16(stbi__context *s) +{ + if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) + return 1; + return 0; } #endif -static int stbi__info_main(stbi__context * s, int * x, int * y, int * comp) { -#ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) - return 1; -#endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) +{ + #ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) return 1; + #endif -#ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) - return 1; -#endif + #ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) return 1; + #endif -// test tga last because it's a crappy test! -#ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; -#endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); + // test tga last because it's a crappy test! + #ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; + #endif + return stbi__err("unknown image type", "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context * s) { -#ifndef STBI_NO_PNG - if (stbi__png_is16(s)) - return 1; -#endif +static int stbi__is_16_main(stbi__context *s) +{ + #ifndef STBI_NO_PNG + if (stbi__png_is16(s)) return 1; + #endif -#ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) - return 1; -#endif + #ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) return 1; + #endif -#ifndef STBI_NO_PNM - if (stbi__pnm_is16(s)) - return 1; -#endif - return 0; + #ifndef STBI_NO_PNM + if (stbi__pnm_is16(s)) return 1; + #endif + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) +{ + FILE *f = stbi__fopen(filename, "rb"); int result; - if (!f) - return stbi__err("can't fopen", "Unable to open file"); + if (!f) return stbi__err("can't fopen", "Unable to open file"); result = stbi_info_from_file(f, x, y, comp); fclose(f); return result; } -STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp) { - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s, x, y, comp); - fseek(f, pos, SEEK_SET); - return r; +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s,x,y,comp); + fseek(f,pos,SEEK_SET); + return r; } -STBIDEF int stbi_is_16_bit(char const * filename) { - FILE * f = stbi__fopen(filename, "rb"); +STBIDEF int stbi_is_16_bit(char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); int result; - if (!f) - return stbi__err("can't fopen", "Unable to open file"); + if (!f) return stbi__err("can't fopen", "Unable to open file"); result = stbi_is_16_bit_from_file(f); fclose(f); return result; } -STBIDEF int stbi_is_16_bit_from_file(FILE * f) { - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f, pos, SEEK_SET); - return r; +STBIDEF int stbi_is_16_bit_from_file(FILE *f) +{ + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f,pos,SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__info_main(&s, x, y, comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__info_main(&s,x,y,comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * c, void * user, int * x, int * y, int * comp) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); - return stbi__info_main(&s, x, y, comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__info_main(&s,x,y,comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len) { - stbi__context s; - stbi__start_mem(&s, buffer, len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * user) { - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION @@ -8279,9 +7867,12 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * us 1.30 (2011-06-11) added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks - anyway error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in - decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from Aurelien Pocheville 1.28 (2010-08-01) + removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway + error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) + fix inefficiency in decoding 32-bit BMP (David Woo) + 1.29 (2010-08-16) + various warning fixes from Aurelien Pocheville + 1.28 (2010-08-01) fix bug in GIF palette transparency (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings @@ -8353,6 +7944,7 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * us first released version */ + /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 550dd5cfd..caa41aee5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -63,6 +63,7 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + is_lora: bool # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -70,7 +71,7 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -92,6 +93,7 @@ class Model: self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py + self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -295,6 +297,7 @@ class Model: gguf.MODEL_TENSOR.FFN_GATE_INP, gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, + gguf.MODEL_TENSOR.SSM_CONV1D, ) ) or not name.endswith(".weight") @@ -590,6 +593,15 @@ class Model: if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249": # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M res = "smollm" + if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7": + # ref: https://huggingface.co/bigscience/bloom + res = "bloom" + if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21": + # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small + res = "gpt3-finnish" + if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae": + # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct + res = "exaone" if res is None: logger.warning("\n") @@ -893,7 +905,7 @@ class GPTNeoXModel(Model): return tensors -@Model.register("BloomForCausalLM") +@Model.register("BloomForCausalLM", "BloomModel") class BloomModel(Model): model_arch = gguf.MODEL_ARCH.BLOOM @@ -1560,7 +1572,7 @@ class LlamaModel(Model): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -1583,7 +1595,8 @@ class LlamaModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + if not self.is_lora: + self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) super().prepare_tensors() @@ -2130,8 +2143,9 @@ class Phi3MiniModel(Model): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + if not self.is_lora: + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) @Model.register("PlamoForCausalLM") @@ -2702,7 +2716,7 @@ class StarCoder2Model(Model): model_arch = gguf.MODEL_ARCH.STARCODER2 -@Model.register("MambaForCausalLM", "MambaLMHeadModel") +@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(Model): model_arch = gguf.MODEL_ARCH.MAMBA @@ -2733,7 +2747,10 @@ class MambaModel(Model): # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 - + use_dt_b_c_norm = False + # For falconmamba we do apply RMS norm on B / DT and C layers + if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",): + use_dt_b_c_norm = True # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model @@ -2741,12 +2758,13 @@ class MambaModel(Model): self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_time_step_rank(dt_rank) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers self.gguf_writer.add_file_type(self.ftype) _tok_embd = None @@ -2773,23 +2791,6 @@ class MambaModel(Model): return [(new_name, data_torch)] - def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: - if bid is not None and new_name in ( - self.format_tensor_name( - n, bid, ".weight" if name.endswith(".weight") else "" - ) - for n in [ - gguf.MODEL_TENSOR.SSM_CONV1D, - gguf.MODEL_TENSOR.SSM_X, - gguf.MODEL_TENSOR.SSM_DT, - gguf.MODEL_TENSOR.SSM_A, - gguf.MODEL_TENSOR.SSM_D, - ] - ): - return gguf.GGMLQuantizationType.F32 - - return super().tensor_force_quant(name, new_name, bid, n_dims) - @Model.register("CohereForCausalLM") class CommandR2Model(Model): @@ -3734,8 +3735,121 @@ class ChatGLMModel(Model): name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] -###### CONVERSION LOGIC ###### +@Model.register("NemotronForCausalLM") +class NemotronModel(Model): + model_arch = gguf.MODEL_ARCH.NEMOTRON + + def set_vocab(self): + self._set_vocab_sentencepiece() + self.gguf_writer.add_pad_token_id(0) + self.gguf_writer.add_unk_token_id(1) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"]) + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + + # * Partial RoPE + rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + + # * RopeScaling for Nemotron + if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side + # model.layers.{l}.input_layernorm.weight + # model.layers.{l}.post_attention_layernorm.weight + # model.norm.weight + if name.endswith("norm.weight"): + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + +@Model.register("ExaoneForCausalLM") +class ExaoneModel(Model): + model_arch = gguf.MODEL_ARCH.EXAONE + + def set_gguf_parameters(self): + hparams = self.hparams + + assert (hparams["activation_function"] == "silu") + + max_position_embeddings = hparams["max_position_embeddings"] + embed_dim = hparams["hidden_size"] + num_heads = hparams["num_attention_heads"] + num_kv_heads = hparams.get("num_key_value_heads", num_heads) + layer_norm_eps = hparams["layer_norm_epsilon"] + intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim + num_layers = hparams["num_layers"] + # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 + # attention_dropout_rate = hparams["attention_dropout"] + # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 + # embed_dropout_rate = hparams["embed_dropout"] + self.gguf_writer.add_embedding_length(embed_dim) + self.gguf_writer.add_head_count(num_heads) + self.gguf_writer.add_head_count_kv(num_kv_heads) + self.gguf_writer.add_context_length(max_position_embeddings) + self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_file_type(self.ftype) + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = rotary_factor if rotary_factor is not None else 1.0 + self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]: + if hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + + def prepare_tensors(self): + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + if not self.is_lora: + self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + + super().prepare_tensors() + + +###### CONVERSION LOGIC ###### # tree of lazy tensors class LazyTorchTensor(gguf.LazyBase): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index d5a2d925e..ff4955f9c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -94,6 +94,9 @@ models = [ {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, + {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, + {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, ] diff --git a/convert_llama_ggml_to_gguf.py b/convert_llama_ggml_to_gguf.py index 7b00b4398..29b14e98d 100755 --- a/convert_llama_ggml_to_gguf.py +++ b/convert_llama_ggml_to_gguf.py @@ -116,7 +116,7 @@ class Tensor: assert quant is not None, 'Unknown tensor type' (blksize, tysize) = quant offset += 12 - self.dtype= dtype + self.dtype= gguf.GGMLQuantizationType(dtype) self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) offset += 4 * n_dims self.name = bytes(data[offset:offset + name_len]) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index a88d0d4a9..ddd347a2a 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -386,6 +386,7 @@ if __name__ == '__main__': dry_run=args.dry_run, dir_lora_model=dir_lora, lora_alpha=alpha, + is_lora=True, ) logger.info("Exporting model...") diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md new file mode 100644 index 000000000..6bdd9d2da --- /dev/null +++ b/docs/backend/CANN.md @@ -0,0 +1,259 @@ +# llama.cpp for CANN + + - [Background](#background) + - [News](#news) + - [OS](#os) + - [Hardware](#hardware) + - [Model Supports](#model-supports) + - [DataType Supports](#datatype-supports) + - [Docker](#docker) + - [Linux](#linux) + - [TODO](#todo) + + +## Background + +**Ascend NPU** is a range of AI processors using Neural Processing Unit. It will efficiently handle matrix-matrix multiplication, dot-product and scalars. + +**CANN** (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for AI scenarios, providing support for multiple AI frameworks on the top and serving AI processors and programming at the bottom. It plays a crucial role in bridging the gap between upper and lower layers, and is a key platform for improving the computing efficiency of Ascend AI processors. Meanwhile, it offers a highly efficient and easy-to-use programming interface for diverse application scenarios, allowing users to rapidly build AI applications and services based on the Ascend platform. + +**Llama.cpp + CANN** + +The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the ability of AscendC and ACLNN which are intergrated to CANN Toolkit and kernels to using Ascend NPU directly. + +## News + +- 2024.8 + - Support `Q4_0` and `Q8_0` data type for Ascend NPU. +- 2024.7 + - Create CANN backend for Ascend NPU. + +## OS + +| OS | Status | Verified | +|:-------:|:-------:|:----------------------------------------------:| +| Linux | Support | Ubuntu 22.04, OpenEuler22.03 | + + +## Hardware + +### Ascend NPU + +**Verified devices** +| Ascend NPU | Status | +|:-----------------------------:|:-------:| +| Atlas 300T A2 | Support | + +*Notes:* + +- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag. +- If you run successfully with your Ascend NPU device, please help update the upper table. + + +## Model Supports + +| Model Name | FP16 | Q8_0 | Q4_0 | +|:----------------------------|:-----:|:----:|:----:| +| AquilaChat2-7B | √ | √ | √ | +| Baichuan-7b | √ | √ | √ | +| Baichuan2-7B-Chat | √ | √ | √ | +| bitnet_b1_58-large | √ | √ | √ | +| bloom-560m | √ | x | √ | +| bloomz-alpaca-560m | √ | x | √ | +| c4ai-command-r-35B-v01 | x | x | x | +| chatglm3-6B | x | x | x | +| chinese-alpaca-2-1.3b | √ | √ | √ | +| CodeShell-7B | √ | √ | √ | +| deepseek-ai_deepseek-coder-1.3B-base | x | x | x | +| deepseek-ai_DeepSeek-V2-Lite | x | x | x | +| deepseek-coder-6.7B-instruct | x | x | x | +| DeepSeek-V2-Lite-64x1.5B | x | x | x | +| falcon-7b-instruct | √ | √ | √ | +| flan-t5-large | √ | √ | √ | +| gemma-2-9b-it | √ | √ | √ | +| glm-4-9B | x | x | x | +| gpt2 | √ | √ | √ | +| Gpt2-163M | √ | √ | √ | +| granite-3B-code-instruct | √ | √ | √ | +| GritLM-7B | √ | √ | √ | +| internlm2_5-7b-chat | √ | √ | √ | +| koala-7B-HF | √ | √ | √ | +| Llama-2-7b-chat-hf | √ | √ | √ | +| Llama-3-Smaug-8B | √ | √ | √ | +| Llama2-Chinese-7b-Chat | √ | √ | √ | +| Llama3-8B | √ | √ | √ | +| Llama3-8b-chinese | √ | √ | √ | +| mamba-130m-hf | √ | √ | √ | +| Mistral-7B-Instruct-v0.2 | √ | √ | √ | +| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ | +| mpt-7B | √ | √ | √ | +| OLMo-1B-hf | √ | √ | √ | +| OpenELM-3B-Instruct | √ | √ | √ | +| Orion-14b-base | √ | √ | √ | +| phi1 | x | x | x | +| phi2 | x | x | x | +| Phi-3-mini-4k-instruct | √ | √ | √ | +| plamo-13b | √ | √ | √ | +| pythia-70M | x | x | x | +| Qwen-7B | √ | √ | √ | +| Qwen2-1.5B-Instruct | √ | x | √ | +| Refact-1_6B-fim | √ | √ | √ | +| SmolLM-135M | √ | √ | √ | +| stablelm-zephyr | x | x | x | +| stablelm-2-zephyr-1_6b | x | x | x | +| starcoderbase-1b | √ | √ | √ | +| starcoder2-3b | √ | √ | √ | +| vigogne-7b-chat | √ | √ | √ | +| xverse-7b-chat | √ | √ | √ | +| Yi-6b-Chat | √ | √ | √ | + + + +## DataType Supports + +| DataType | Status | +|:----------------------:|:-------:| +| FP16 | Support | +| Q8_0 | Support | +| Q4_0 | Support | + +## Docker + +### Build Images +You can get a image with llama.cpp in one command. +```sh +docker build -t llama-cpp-cann -f .devops/llama-cli-cann.Dockerfile . +``` + +### Run container + +```sh +# Find all cards. +npu-smi info + +# Select the cards that you want to use, make sure these cards are not used by someone. +# Following using cards of device0. +docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:" +``` + +*Notes:* + +- You may need to install Ascend Driver and firmware on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*. + +## Linux + +### I. Setup Environment + +1. **Install Ascend Driver and firmware** + + ```sh + # create driver running user. + sudo groupadd -g HwHiAiUser + sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash + sudo usermod -aG HwHiAiUser $USER + + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all + ``` + + Once installed, run `npu-smi info` to check whether driver is installed successfully. + ```sh + +-------------------------------------------------------------------------------------------+ + | npu-smi 24.1.rc2 Version: 24.1.rc2 | + +----------------------+---------------+----------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)| + | Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) | + +======================+===============+====================================================+ + | 2 xxx | OK | 64.4 51 15 / 15 | + | 0 | 0000:01:00.0 | 0 1873 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | 5 xxx | OK | 64.0 52 15 / 15 | + | 0 | 0000:81:00.0 | 0 1874 / 15077 0 / 32768 | + +======================+===============+====================================================+ + | No running processes found in NPU 2 | + +======================+===============+====================================================+ + | No running processes found in NPU 5 | + +======================+===============+====================================================+ + ``` + +2. **Install Ascend Firmware** + ```sh + # download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system + # and install driver. + sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full + ``` + If the following messaage appers, firmware is installed successfully. + ```sh + Firmware package installed successfully! + ``` + + +3. **Install CANN toolkit and kernels** + + CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page. + + Please download the corresponding version that satified your system. The minimum version required is 8.0.RC2.alpha002 and here is the install command. + ```sh + pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions + sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install + sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install + ``` + + Set Ascend Variables: + ```sh + echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc + source ~/.bashrc + ``` + +Upon a successful installation, CANN is enabled for the available ascend devices. + +### II. Build llama.cpp + +```sh +cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release +cmake --build build --config release +``` + +### III. Run the inference + +1. **Retrieve and prepare model** + + You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration. + + **Notes**: + + - CANN backend only supports FP16/Q4_0/Q8_0 models currently. + +2. **Launch inference** + + There are two device selection modes: + + - Single device: Use one device target specified by the user. + - Multiple devices: Automatically choose the devices with the same backend. + + | Device selection | Parameter | + |:----------------:|:--------------------------------------:| + | Single device | --split-mode none --main-gpu DEVICE_ID | + | Multiple devices | --split-mode layer (default) | + + Examples: + + - Use device 0: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0 + ``` + + - Use multiple devices: + + ```sh + ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer + ``` + +### **GitHub contribution**: +Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. + + +## TODO +- Support more models and data types. diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 59a39fbb6..e3b9572cc 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -20,7 +20,7 @@ **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. -- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL - Math Kernel Library)*. +- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*. - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. @@ -28,10 +28,6 @@ The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*). -When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend. - -It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose. - ## Recommended Release The SYCL backend would be broken by some PRs due to no online CI. @@ -45,6 +41,10 @@ The following release is verified with good quality: ## News + +- 2024.8 + - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. + - 2024.5 - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. - Arch Linux is verified successfully. @@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. -Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI MKL for intel GPUs. +Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs. - **Adding support to Nvidia GPUs** @@ -255,8 +255,6 @@ or # Export relevant ENV variables source /opt/intel/oneapi/setvars.sh -# Build LLAMA with MKL BLAS acceleration for intel GPU - # Option 1: Use FP32 (recommended for better performance in most cases) cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx @@ -338,12 +336,12 @@ Choose one of following methods to run. - Use device 0: ```sh -./examples/sycl/run_llama2.sh 0 +./examples/sycl/run-llama2.sh 0 ``` - Use multiple devices: ```sh -./examples/sycl/run_llama2.sh +./examples/sycl/run-llama2.sh ``` 2. Command line diff --git a/docs/build.md b/docs/build.md index 8b16d1a35..152d46d6f 100644 --- a/docs/build.md +++ b/docs/build.md @@ -352,6 +352,31 @@ cmake --build build --config Release # ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32 ``` +### CANN +This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU. + +For more information about Ascend NPU in [Ascend Community](https://www.hiascend.com/en/). + +Make sure to have the CANN toolkit installed. You can download it from here: [CANN Toolkit](https://www.hiascend.com/developer/download/community/result?module=cann) + +Go to `llama.cpp` directory and build using CMake. +```bash +cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release +cmake --build build --config release +``` + +You can test with: + +`./build/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32` + +If the fllowing info is output on screen, you are using `llama.cpp by CANN backend`: +```bash +llm_load_tensors: CANN buffer size = 13313.00 MiB +llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB +``` + +For detailed info, such as model/device supports, CANN install, please refer to [llama.cpp for CANN](./backend/CANN.md). + ### Android To read documentation for how to build on Android, [click here](./android.md) diff --git a/docs/docker.md b/docs/docker.md index d8922d77d..e25838255 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -66,8 +66,8 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment The defaults are: -- `CUDA_VERSION` set to `11.7.1` -- `CUDA_DOCKER_ARCH` set to `all` +- `CUDA_VERSION` set to `12.6.0` +- `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures The resulting images, are essentially the same as the non-CUDA images: diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index aca332e94..3ce91070b 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -18,7 +18,7 @@ constexpr float rms_norm_eps = 5e-6f; #endif static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); if (plan.work_size > 0) { buf.resize(plan.work_size); diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp index 47cb16c69..97622f4f4 100644 --- a/examples/benchmark/benchmark-matmult.cpp +++ b/examples/benchmark/benchmark-matmult.cpp @@ -21,7 +21,7 @@ #endif static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); if (plan.work_size > 0) { buf.resize(plan.work_size); @@ -54,7 +54,7 @@ static void tensor_dump(const ggml_tensor * tensor, const char * name) { #define TENSOR_DUMP(tensor) tensor_dump(tensor, #tensor) struct benchmark_params_struct { - int32_t n_threads = 1; + int n_threads = 1; int32_t n_iterations = 10; }; diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index a12e90d82..a68268388 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -271,7 +271,7 @@ struct tokenized_prompt { size_t max_seq_len; tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); @@ -486,8 +486,8 @@ int main(int argc, char ** argv) { if (use_pca) { // run PCA PCA::pca_params pca_params; - pca_params.n_threads = params.n_threads; - pca_params.n_batch = params.n_pca_batch; + pca_params.n_threads = params.cpuparams.n_threads; + pca_params.n_batch = params.n_pca_batch; pca_params.n_iterations = params.n_pca_iterations; PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); } else { diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index ef35ba2c0..5e89988e2 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -127,7 +127,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } static bool run(llama_context * ctx, const gpt_params & params) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index c7e5ca788..8df457e21 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -410,7 +410,7 @@ int main(int argc, char ** argv) { g_verbose = (params.verbosity == 1); try { - lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads); + lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.cpuparams.n_threads); ctx.run_merge(); } catch (const std::exception & err) { fprintf(stderr, "%s\n", err.what()); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 58814b96e..83b85d72b 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -433,8 +433,8 @@ static void process_logits( } static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const int n_ctx = llama_n_ctx(ctx); auto tim1 = std::chrono::high_resolution_clock::now(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 92d630b15..05700c1d5 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -203,8 +203,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); } - const bool add_bos = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); + const bool add_bos = llama_add_bos_token(model); + GGML_ASSERT(!llama_add_eos_token(model)); LOG("add_bos: %d\n", add_bos); std::vector embd_inp; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 42918bfc7..8edadef90 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include "ggml.h" #include "llama.h" @@ -225,6 +226,9 @@ struct cmd_params { std::vector type_k; std::vector type_v; std::vector n_threads; + std::vector cpu_mask; + std::vector cpu_strict; + std::vector poll; std::vector n_gpu_layers; std::vector rpc_servers; std::vector split_mode; @@ -236,6 +240,8 @@ struct cmd_params { std::vector embeddings; ggml_numa_strategy numa; int reps; + ggml_sched_priority prio; + int delay; bool verbose; output_formats output_format; output_formats output_format_stderr; @@ -251,6 +257,9 @@ static const cmd_params cmd_params_defaults = { /* type_k */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16}, /* n_threads */ {cpu_get_num_math()}, + /* cpu_mask */ {"0x0"}, + /* cpu_strict */ {false}, + /* poll */ {50}, /* n_gpu_layers */ {99}, /* rpc_servers */ {""}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, @@ -262,6 +271,8 @@ static const cmd_params cmd_params_defaults = { /* embeddings */ {false}, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, + /* prio */ GGML_SCHED_PRIO_NORMAL, + /* delay */ 0, /* verbose */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, @@ -281,6 +292,9 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); + printf(" -C, --cpu-mask (default: %s)\n", join(cmd_params_defaults.cpu_mask, ",").c_str()); + printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str()); + printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); @@ -292,6 +306,8 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); + printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); + printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); printf(" -o, --output (default: %s)\n", output_format_str(cmd_params_defaults.output_format)); printf(" -oe, --output-err (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); @@ -338,6 +354,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { params.output_format_stderr = cmd_params_defaults.output_format_stderr; params.reps = cmd_params_defaults.reps; params.numa = cmd_params_defaults.numa; + params.prio = cmd_params_defaults.prio; + params.delay = cmd_params_defaults.delay; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -433,6 +451,27 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + } else if (arg == "-C" || arg == "--cpu-mask") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end()); + } else if (arg == "--cpu-strict") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end()); + } else if (arg == "--poll") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.poll.insert(params.poll.end(), p.begin(), p.end()); } else if (arg == "-ngl" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; @@ -541,6 +580,18 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.reps = std::stoi(argv[i]); + } else if (arg == "--prio") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.prio = (enum ggml_sched_priority) std::stoi(argv[i]); + } else if (arg == "--delay") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.delay = std::stoi(argv[i]); } else if (arg == "-o" || arg == "--output") { if (++i >= argc) { invalid_param = true; @@ -585,6 +636,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } + if (params.cpu_mask.empty()) { params.cpu_mask = cmd_params_defaults.cpu_mask; } + if (params.cpu_strict.empty()) { params.cpu_strict = cmd_params_defaults.cpu_strict; } + if (params.poll.empty()) { params.poll = cmd_params_defaults.poll; } return params; } @@ -598,6 +652,9 @@ struct cmd_params_instance { ggml_type type_k; ggml_type type_v; int n_threads; + std::string cpu_mask; + bool cpu_strict; + int poll; int n_gpu_layers; std::string rpc_servers; llama_split_mode split_mode; @@ -667,7 +724,10 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) - for (const auto & nt : params.n_threads) { + for (const auto & nt : params.n_threads) + for (const auto & cm : params.cpu_mask) + for (const auto & cs : params.cpu_strict) + for (const auto & pl : params.poll) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { continue; @@ -681,6 +741,9 @@ static std::vector get_cmd_params_instances(const cmd_param /* .type_k = */ tk, /* .type_v = */ tv, /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, /* .n_gpu_layers = */ nl, /* .rpc_servers = */ rpc, /* .split_mode = */ sm, @@ -707,6 +770,9 @@ static std::vector get_cmd_params_instances(const cmd_param /* .type_k = */ tk, /* .type_v = */ tv, /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, /* .n_gpu_layers = */ nl, /* .rpc_servers = */ rpc, /* .split_mode = */ sm, @@ -733,6 +799,9 @@ static std::vector get_cmd_params_instances(const cmd_param /* .type_k = */ tk, /* .type_v = */ tv, /* .n_threads = */ nt, + /* .cpu_mask = */ cm, + /* .cpu_strict = */ cs, + /* .poll = */ pl, /* .n_gpu_layers = */ nl, /* .rpc_servers = */ rpc, /* .split_mode = */ sm, @@ -769,6 +838,9 @@ struct test { int n_batch; int n_ubatch; int n_threads; + std::string cpu_mask; + bool cpu_strict; + int poll; bool has_rpc; ggml_type type_k; ggml_type type_v; @@ -795,6 +867,9 @@ struct test { n_batch = inst.n_batch; n_ubatch = inst.n_ubatch; n_threads = inst.n_threads; + cpu_mask = inst.cpu_mask; + cpu_strict = inst.cpu_strict; + poll = inst.poll; has_rpc = !inst.rpc_servers.empty(); type_k = inst.type_k; type_v = inst.type_v; @@ -872,13 +947,14 @@ struct test { "cpu_info", "gpu_info", "model_filename", "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", - "n_threads", "type_k", "type_v", + "n_threads", "cpu_mask", "cpu_strict", "poll", + "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", - "avg_ts", "stddev_ts" + "avg_ts", "stddev_ts", }; return fields; } @@ -887,7 +963,7 @@ struct test { static field_type get_field_type(const std::string & field) { if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || - field == "n_threads" || + field == "n_threads" || field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || field == "n_prompt" || field == "n_gen" || @@ -896,6 +972,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || + field == "cpu_strict" || field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } @@ -928,7 +1005,8 @@ struct test { cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), std::to_string(n_batch), std::to_string(n_ubatch), - std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), + std::to_string(n_threads), cpu_mask, std::to_string(cpu_strict), std::to_string(poll), + ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), @@ -1067,7 +1145,7 @@ struct markdown_printer : public printer { return -30; } if (field == "t/s") { - return 16; + return 20; } if (field == "size" || field == "params") { return 10; @@ -1149,6 +1227,15 @@ struct markdown_printer : public printer { if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) { fields.emplace_back("n_threads"); } + if (params.cpu_mask.size() > 1 || params.cpu_mask != cmd_params_defaults.cpu_mask) { + fields.emplace_back("cpu_mask"); + } + if (params.cpu_strict.size() > 1 || params.cpu_strict != cmd_params_defaults.cpu_strict) { + fields.emplace_back("cpu_strict"); + } + if (params.poll.size() > 1 || params.poll != cmd_params_defaults.poll) { + fields.emplace_back("poll"); + } if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) { fields.emplace_back("n_batch"); } @@ -1383,6 +1470,8 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); + set_process_priority(params.prio); + // initialize printer std::unique_ptr p = create_printer(params.output_format); std::unique_ptr p_err = create_printer(params.output_format_stderr); @@ -1428,6 +1517,28 @@ int main(int argc, char ** argv) { llama_kv_cache_clear(ctx); + // cool off before the test + if (params.delay) { + std::this_thread::sleep_for(std::chrono::seconds(params.delay)); + } + + struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads); + if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) { + LOG_TEE("%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str()); + exit(1); + } + tpp.strict_cpu = t.cpu_strict; + tpp.poll = t.poll; + tpp.prio = params.prio; + + struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + + llama_attach_threadpool(ctx, threadpool, NULL); + // warmup run if (t.n_prompt > 0) { //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); @@ -1466,6 +1577,8 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + + ggml_threadpool_free(threadpool); } llama_free_model(lmodel); diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 58c32ca53..48b7840ae 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -71,8 +71,8 @@ actor LlamaContext { var ctx_params = llama_context_default_params() ctx_params.seed = 1234 ctx_params.n_ctx = 2048 - ctx_params.n_threads = UInt32(n_threads) - ctx_params.n_threads_batch = UInt32(n_threads) + ctx_params.n_threads = Int32(n_threads) + ctx_params.n_threads_batch = Int32(n_threads) let context = llama_new_context_with_model(model, ctx_params) guard let context else { diff --git a/examples/llava/README-minicpmv2.5.md b/examples/llava/README-minicpmv2.5.md index 4affc1d0f..1c8498ff9 100644 --- a/examples/llava/README-minicpmv2.5.md +++ b/examples/llava/README-minicpmv2.5.md @@ -15,9 +15,9 @@ cd llama.cpp Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us) ```bash -python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 -python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 -python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2 +python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model # quantize int4 version ./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M diff --git a/examples/llava/README-minicpmv2.6.md b/examples/llava/README-minicpmv2.6.md new file mode 100644 index 000000000..c4be5e5dd --- /dev/null +++ b/examples/llava/README-minicpmv2.6.md @@ -0,0 +1,107 @@ +## MiniCPM-V 2.6 + +### Prepare models and code + +Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from huggingface to "MiniCPM-V-2_6" folder. + +Clone llama.cpp: +```bash +git clone git@github.com:OpenBMB/llama.cpp.git +cd llama.cpp +git checkout minicpmv-main +``` + +### Usage of MiniCPM-V 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us) + +```bash +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3 +python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model + +# quantize int4 version +./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build for Linux or Mac + +```bash +make +make llama-minicpmv-cli +``` + +Inference on Linux or Mac +``` +# run f16 version +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` + +### Video +Install FFmpeg +``` +brew install ffmpeg +brew install pkg-config +``` + +### Android + +#### Build on Android device using Termux +We found that build on Android device would bring better runtime performance, so we recommend to build on device. + +[Termux](https://github.com/termux/termux-app#installation) is a terminal app on Android device (no root required). + +Install tools in Termux: +``` +apt update && apt upgrade -y +apt install git make cmake +``` + +It's recommended to move your model inside the `~/` directory for best performance: +``` +cd storage/downloads +mv model.gguf ~/ +``` + +#### Building the Project using Android NDK +Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake. + +Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux: + +```bash +mkdir build-android +cd build-android +export NDK=/your_ndk_path +cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. +make +``` + +Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice). + +Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission: + +(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`) +``` +$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/ +$cd /data/data/com.termux/files/home/bin +$chmod +x ./* +``` + +Download models and push them to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/` + +``` +$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/ +$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/ +``` + +Now, you can start chatting: +``` +$cd /data/data/com.termux/files/home/bin +$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 54aa822c9..9b890571e 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -20,6 +20,10 @@ #include "ggml-cann.h" #endif +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -81,6 +85,7 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_USE_GELU "clip.use_gelu" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" @@ -211,13 +216,19 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { @@ -526,6 +537,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + int minicpmv_version = 2; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -641,7 +653,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (ctx->has_minicpmv_projector) { int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + if (ctx->minicpmv_version == 2) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); + } + else if (ctx->minicpmv_version == 3) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -768,8 +785,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - - } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + } + else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -949,10 +966,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } { // attention - const int hidden_size = 4096; + int hidden_size = 4096; const int d_head = 128; - const int n_head = hidden_size/d_head; - const int num_query = 96; + int n_head = hidden_size/d_head; + int num_query = 96; + if (ctx->minicpmv_version == 2) { + hidden_size = 4096; + n_head = hidden_size/d_head; + num_query = 96; + } + else if (ctx->minicpmv_version == 3) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -1091,7 +1118,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } } - clip_ctx * new_clip = new clip_ctx; + clip_ctx * new_clip = new clip_ctx{}; // update projector type { @@ -1125,6 +1152,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: CLIP using CANN backend\n", __func__); #endif +#ifdef GGML_USE_VULKAN + new_clip->backend = ggml_backend_vk_init(0); + LOG_TEE("%s: CLIP using Vulkan backend\n", __func__); +#endif if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); @@ -1149,6 +1180,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx); } + idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION); + if (idx != -1) { + new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); + } + // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search GGML_ASSERT(new_clip->has_vision_encoder); @@ -1587,7 +1623,7 @@ static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* } } -inline float clip(float x, float lower, float upper) { +inline int clip(int x, int lower, int upper) { return std::max(lower, std::min(x, upper)); } @@ -1791,10 +1827,6 @@ static std::pair uhd_get_refine_size(std::pair original_size return refine_size; } -inline int clip(int x, int lower, int upper) { - return std::max(lower, std::min(x, upper)); -} - static std::pair uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { std::vector candidate_split_grids_nums; for (int i : {multiple - 1, multiple, multiple + 1}) { @@ -1910,10 +1942,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { - if (clip_is_minicpmv(ctx)) { - std::vector> imgs = uhd_slice_image(img); + + if(clip_is_minicpmv(ctx)){ + int max_slice_nums = 9; + std::vector> imgs = uhd_slice_image(img, max_slice_nums); res_imgs->size = 0; - for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t i = 0; i < imgs.size(); ++i){ res_imgs->size += imgs[i].size(); } res_imgs->data = new clip_image_f32[res_imgs->size]; @@ -2146,7 +2180,12 @@ int clip_n_patches(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - n_patches = 96; + if (ctx->minicpmv_version == 2) { + n_patches = 96; + } + else if (ctx->minicpmv_version == 3) { + n_patches = 64; + } } return n_patches; @@ -2282,6 +2321,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + if(ctx->load_image_size==nullptr){ + ctx->load_image_size= clip_image_size_init(); + } + const int pos_w = ctx->load_image_size->width/patch_size; + const int pos_h = ctx->load_image_size->height/patch_size; { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); @@ -2316,8 +2360,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - for (int i = 0; i < num_positions; i++) { - positions_data[i] = std::floor(70.0*i/num_positions); + int bucket_coords_h[70]; + int bucket_coords_w[70]; + for (int i = 0; i < pos_h; i++){ + bucket_coords_h[i] = std::floor(70.0*i/pos_h); + } + for (int i = 0; i < pos_w; i++){ + bucket_coords_w[i] = std::floor(70.0*i/pos_w); + } + for (int i = 0, id = 0; i < pos_h; i++){ + for (int j = 0; j < pos_w; j++){ + positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + } } ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); @@ -2328,12 +2382,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/Qwen/Qwen-VL/tree/main // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); - if(ctx->load_image_size==nullptr){ - ctx->load_image_size= clip_image_size_init(); - } - int pos_w = ctx->load_image_size->width/patch_size; - int pos_h = ctx->load_image_size->height/patch_size; int embed_dim = 4096; + if (ctx->minicpmv_version == 2) { + embed_dim = 4096; + } + else if (ctx->minicpmv_version == 3) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2346,7 +2401,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed)); free(pos_embed_data); } - } else { + } + else{ { if (ctx->has_class_embedding) { struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); @@ -2548,13 +2604,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.mm_3_b->ne[0]; } if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - return 4096; + if (ctx->minicpmv_version == 2) { + return 4096; + } + else if (ctx->minicpmv_version == 3) { + return 3584; + } } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } -bool clip_is_minicpmv(const struct clip_ctx * ctx) { - return ctx->has_minicpmv_projector; +int clip_is_minicpmv(const struct clip_ctx * ctx) { + if (ctx->has_minicpmv_projector) { + return ctx->minicpmv_version; + } + return 0; } diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 2ff4d3992..78588bdf1 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); -CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); #ifdef __cplusplus } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 570b2f116..25feec5c7 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -129,14 +129,14 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para if (!params->image.empty()) { LOG_TEE("using base64 encoded image instead of command line image path\n"); } - embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt); + embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt); if (!embed) { LOG_TEE("%s: can't load image from prompt\n", __func__); return NULL; } params->prompt = remove_image_from_prompt(prompt); } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str()); + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str()); if (!embed) { fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); return NULL; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 916d9dc40..851af0f00 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->width = img_res_v.data[i].nx; load_image_size->height = img_res_v.data[i].ny; clip_add_load_image_size(ctx_clip, load_image_size); - const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + bool encoded = false; + int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); + if (has_minicpmv_projector == 2) { + encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + } + else if (has_minicpmv_projector == 3) { + encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + } if (!encoded) { LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index f951b57b2..f500ea5b9 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e std::string system_prompt; int idx = 0; int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); - system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); + if (has_minicpmv_projector == 2) { + system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; + } + else if (has_minicpmv_projector == 3) { + system_prompt = "<|im_start|>user\n"; + } LOG_TEE("%s: image token past: %d\n", __func__, n_past); eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); @@ -174,7 +180,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ auto ctx_clip = clip_init_context(params); - auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str()); + auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str()); if (!embeds) { std::cerr << "error: failed to load image " << fname << ". Terminating\n\n"; return NULL; @@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; - if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; + int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); + if (!is_first) { + if (has_minicpmv_projector == 2) { + user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; + } + else if (has_minicpmv_projector == 3) { + user_prompt = "<|im_start|>user\n" + prompt; + } + } eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); - eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + if (has_minicpmv_projector == 2) { + eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); + } + else if (has_minicpmv_projector == 3) { + eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } + // generate the response LOG_TEE("\n"); diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py index 12cdd1281..ea773742a 100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -1,9 +1,416 @@ -import argparse +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Siglip model. """ +# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes + + import os +import math +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import ( + logging, +) +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +): + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + denom = fan_in + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.self_attn = ( + SiglipAttention(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + +import argparse import json import re -import torch import numpy as np from gguf import * from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig @@ -94,6 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) # with proper args = ap.parse_args() @@ -135,6 +543,15 @@ if args.use_f32: # model = CLIPModel.from_pretrained(dir_model) # processor = CLIPProcessor.from_pretrained(dir_model) +minicpmv_version = args.minicpmv_version +emb_dim = 4096 +if minicpmv_version == 1: + emb_dim = 2304 +elif minicpmv_version == 2: + emb_dim = 4096 +elif minicpmv_version == 3: + emb_dim = 3584 + default_vision_config = { "hidden_size": 1152, "image_size": 980, @@ -144,8 +561,12 @@ default_vision_config = { "num_hidden_layers": 27, "patch_size": 14, } + vision_config = Idefics2VisionConfig(**default_vision_config) model = Idefics2VisionTransformer(vision_config) +if minicpmv_version == 3: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -158,6 +579,7 @@ fname_middle = None has_text_encoder = True has_vision_encoder = True has_minicpmv_projector = False + if args.text_only: fname_middle = "text-" has_vision_encoder = False @@ -165,6 +587,7 @@ elif args.minicpmv_projector is not None: fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True + minicpmv_version = 3 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -189,6 +612,7 @@ elif has_minicpmv_projector: fout.add_description("image encoder for MiniCPM-V") # add projector type fout.add_string("clip.projector_type", "resampler") + fout.add_int32("clip.minicpmv_version", minicpmv_version) else: fout.add_description("two-tower CLIP model") @@ -274,11 +698,11 @@ def _replace_name_resampler(s, v): if re.match("resampler.pos_embed", s): return { s: v, - re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), } if re.match("resampler.proj", s): return { - re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), + re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), } if re.match("resampler.attn.in_proj_.*", s): diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 2b6bce7cf..748ff5c57 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -4,7 +4,7 @@ import torch from transformers import AutoModel, AutoTokenizer ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model") +ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights @@ -29,7 +29,6 @@ if len(clip_tensors) > 0: f.write("{}\n") config = model.llm.config -config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5" config.auto_map = { "AutoConfig": "configuration_minicpm.MiniCPMConfig", "AutoModel": "modeling_minicpm.MiniCPMModel", @@ -40,7 +39,6 @@ config.auto_map = { model.llm.save_pretrained(f"{args.model}/model") tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) tok.save_pretrained(f"{args.model}/model") -# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py") print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e0635a66..2c05afb04 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -221,6 +221,40 @@ int main(int argc, char ** argv) { return 1; } + LOG("%s: llama threadpool init = n_threads = %d\n", + __func__, + (int) params.cpuparams.n_threads + ); + struct ggml_threadpool_params tpp_batch = + ggml_threadpool_params_from_cpu_params(params.cpuparams_batch); + struct ggml_threadpool_params tpp = + ggml_threadpool_params_from_cpu_params(params.cpuparams); + + set_process_priority(params.cpuparams.priority); + + struct ggml_threadpool * threadpool_batch = NULL; + if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) { + threadpool_batch = ggml_threadpool_new(&tpp_batch); + if (!threadpool_batch) { + LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads); + exit(1); + } + + // Start the non-batch threadpool in the paused state + tpp.paused = true; + } + + struct ggml_threadpool * threadpool = ggml_threadpool_new(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + + llama_attach_threadpool(ctx, threadpool, threadpool_batch); + if (ctx_guidance) { + llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch); + } + const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); LOG("n_ctx: %d\n", n_ctx); @@ -267,9 +301,9 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_should_add_bos_token(model); + const bool add_bos = llama_add_bos_token(model); if (!llama_model_has_encoder(model)) { - GGML_ASSERT(llama_add_eos_token(model) != 1); + GGML_ASSERT(!llama_add_eos_token(model)); } LOG("add_bos: %d\n", add_bos); @@ -989,6 +1023,9 @@ int main(int argc, char ** argv) { llama_sampling_free(ctx_sampling); llama_backend_free(); + ggml_threadpool_free(threadpool); + ggml_threadpool_free(threadpool_batch); + #ifndef LOG_DISABLE_LOGS LOG_TEE("Log end\n"); #endif // LOG_DISABLE_LOGS diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 372684f09..484dd5891 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -340,8 +340,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); @@ -480,8 +480,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -1733,8 +1733,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; - const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); + const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); + GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); diff --git a/examples/quantize/README.md b/examples/quantize/README.md index 553c2701b..5d1e11c67 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -34,7 +34,7 @@ Run the quantized model: ```bash # start inference on a gguf model -./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -n 128 +./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant" ``` When running the larger models, make sure you have enough disk space to store all the intermediate files. diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7312309ae..202346310 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -104,7 +104,7 @@ static void usage(const char * executable) { printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); - printf(" --keep-split: will generate quatized model in the same shards as input"); + printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); printf("Note: --include-weights and --exclude-weights cannot be used together\n"); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 65b19ce71..aab9d8105 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -253,6 +253,8 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } + struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; while (true) { @@ -260,7 +262,6 @@ int main(int argc, char ** argv) { std::getline(std::cin, query); std::vector query_tokens = llama_tokenize(ctx, query, true); - struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); @@ -293,6 +294,7 @@ int main(int argc, char ** argv) { } // clean up + llama_batch_free(query_batch); llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index e17595fe8..805e05b4a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -247,6 +247,51 @@ logging: --log-append Don't truncate the old log file. ``` +Available environment variables (if specified, these variables will override parameters specified in arguments): + +- `LLAMA_CACHE`: cache directory, used by `--hf-repo` +- `HF_TOKEN`: Hugging Face access token, used when accessing a gated model with `--hf-repo` +- `LLAMA_ARG_MODEL`: equivalent to `-m` +- `LLAMA_ARG_MODEL_URL`: equivalent to `-mu` +- `LLAMA_ARG_MODEL_ALIAS`: equivalent to `-a` +- `LLAMA_ARG_HF_REPO`: equivalent to `--hf-repo` +- `LLAMA_ARG_HF_FILE`: equivalent to `--hf-file` +- `LLAMA_ARG_THREADS`: equivalent to `-t` +- `LLAMA_ARG_CTX_SIZE`: equivalent to `-c` +- `LLAMA_ARG_N_PARALLEL`: equivalent to `-np` +- `LLAMA_ARG_BATCH`: equivalent to `-b` +- `LLAMA_ARG_UBATCH`: equivalent to `-ub` +- `LLAMA_ARG_N_GPU_LAYERS`: equivalent to `-ngl` +- `LLAMA_ARG_THREADS_HTTP`: equivalent to `--threads-http` +- `LLAMA_ARG_CHAT_TEMPLATE`: equivalent to `--chat-template` +- `LLAMA_ARG_N_PREDICT`: equivalent to `-n` +- `LLAMA_ARG_ENDPOINT_METRICS`: if set to `1`, it will enable metrics endpoint (equivalent to `--metrics`) +- `LLAMA_ARG_ENDPOINT_SLOTS`: if set to `0`, it will **disable** slots endpoint (equivalent to `--no-slots`). This feature is enabled by default. +- `LLAMA_ARG_EMBEDDINGS`: if set to `1`, it will enable embeddings endpoint (equivalent to `--embeddings`) +- `LLAMA_ARG_FLASH_ATTN`: if set to `1`, it will enable flash attention (equivalent to `-fa`) +- `LLAMA_ARG_CONT_BATCHING`: if set to `0`, it will **disable** continuous batching (equivalent to `--no-cont-batching`). This feature is enabled by default. +- `LLAMA_ARG_DEFRAG_THOLD`: equivalent to `-dt` +- `LLAMA_ARG_HOST`: equivalent to `--host` +- `LLAMA_ARG_PORT`: equivalent to `--port` + +Example usage of docker compose with environment variables: + +```yml +services: + llamacpp-server: + image: ghcr.io/ggerganov/llama.cpp:server + ports: + - 8080:8080 + volumes: + - ./models:/models + environment: + # alternatively, you can use "LLAMA_ARG_MODEL_URL" to download the model + LLAMA_ARG_MODEL: /models/my_model.gguf + LLAMA_ARG_CTX_SIZE: 4096 + LLAMA_ARG_N_PARALLEL: 2 + LLAMA_ARG_ENDPOINT_METRICS: 1 # to disable, either remove or set to 0 + LLAMA_ARG_PORT: 8080 +``` ## Build @@ -368,15 +413,16 @@ node index.js ## API Endpoints -### GET `/health`: Returns the current state of the server +### GET `/health`: Returns heath check result - - 503 -> `{"status": "loading model"}` if the model is still being loaded. - - 500 -> `{"status": "error"}` if the model failed to load. - - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below. - - 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available. - - 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available. +**Response format** - If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set. +- HTTP status code 503 + - Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}` + - Explanation: the model is still being loaded. +- HTTP status code 200 + - Body: `{"status": "ok" }` + - Explanation: the model is successfully loaded and the server is ready. ### POST `/completion`: Given a `prompt`, it returns the predicted completion. @@ -639,10 +685,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte }' ``` -### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. +### GET `/slots`: Returns the current slots processing state + +This endpoint can be disabled with `--no-slots` + +If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots. **Response format** +Example: + ```json [ { @@ -702,7 +754,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte ] ``` -### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled: +Possible values for `slot[i].state` are: +- `0`: SLOT_STATE_IDLE +- `1`: SLOT_STATE_PROCESSING + +### GET `/metrics`: Prometheus compatible metrics exporter + +This endpoint is only accessible if `--metrics` is set. Available metrics: - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. @@ -767,6 +825,10 @@ Available metrics: ### GET `/lora-adapters`: Get list of all LoRA adapters +This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` + +By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` + If an adapter is disabled, the scale will be set to 0. **Response format** diff --git a/examples/server/public/index.js b/examples/server/public/index.js index 670960939..fe615ca25 100644 --- a/examples/server/public/index.js +++ b/examples/server/public/index.js @@ -1 +1 @@ -const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function s(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function f(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}f.prototype.brand=t;f.prototype.h=function(){return!0};f.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};f.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};f.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};f.prototype.valueOf=function(){return this.value};f.prototype.toString=function(){return this.value+""};f.prototype.toJSON=function(){return this.value};f.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(f.prototype,"value",{get(){const t=s(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new f(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){f.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new f).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}f.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){f.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=s(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,D,T,M={},F=[],A=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,V=Array.isArray;function W(t,n){for(var e in n)t[e]=n[e];return t}function L(t){var n=t.parentNode;n&&n.removeChild(t)}function O(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return R(t,r,_,i,null)}function R(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function I(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,s,f){var c,h,a,p,d,v=_&&_.__k||F,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?R(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i)?(i.__=t,i.__b=t.__b+1,u=Z(i,e,r,f),i.__i=u,o=null,-1!==u&&(f--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u===r+1?c++:u>r?f>l-r?c+=u-r:c--:u(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),R(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+T++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=[],(_={})[n]=this,this.getChildContext=function(){return _},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.some((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.push(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e.splice(e.indexOf(t),1),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=F.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=W({},this.state),"function"==typeof t&&(t=t(W({},e),this.props)),t&&W(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),D=et(!0),T=0;var at,pt,dt,vt,yt=0,mt=[],gt=[],bt=S,kt=bt.__b,wt=bt.__r,St=bt.diffed,xt=bt.__c,Ct=bt.unmount,Ut=bt.__;function Et(t,n){bt.__h&&bt.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({__V:gt}),e.__[t]}function Ht(t){return yt=1,Pt(zt,t)}function Pt(t,n,e){var _=Et(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):zt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Nt(t,n){var e=Et(at++,3);!bt.__s&&Bt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function $t(t,n){var e=Et(at++,4);!bt.__s&&Bt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function Dt(t){return yt=5,Mt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,$t((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Mt(t,n){var e=Et(at++,7);return Bt(e.__H,n)?(e.__V=t(),e.i=n,e.__h=t,e.__V):e.__}function Ft(t,n){return yt=8,Mt((function(){return t}),n)}function At(t){var n=pt.context[t.__c],e=Et(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function Vt(t,n){bt.useDebugValue&&bt.useDebugValue(n?n(t):t)}function Wt(t){var n=Et(at++,10),e=Ht();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Et(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(jt),t.__H.__h.forEach(qt),t.__H.__h=[]}catch(n){t.__H.__h=[],bt.__e(n,t.__v)}}bt.__b=function(t){pt=null,kt&&kt(t)},bt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ut&&Ut(t,n)},bt.__r=function(t){wt&&wt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.__V=gt,t.__N=t.i=void 0}))):(n.__h.forEach(jt),n.__h.forEach(qt),n.__h=[],at=0)),dt=pt},bt.diffed=function(t){St&&St(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===bt.requestAnimationFrame||((vt=bt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.__V!==gt&&(t.__=t.__V),t.i=void 0,t.__V=gt}))),dt=pt=null},bt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(jt),t.__h=t.__h.filter((function(t){return!t.__||qt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],bt.__e(r,t.__v)}})),xt&&xt(t,n)},bt.unmount=function(t){Ct&&Ct(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{jt(t)}catch(t){n=t}})),e.__H=void 0,n&&bt.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function jt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function qt(t){var n=pt;t.__c=t.__(),pt=n}function Bt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function zt(t,n){return"function"==typeof n?n(t):n}function Gt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Jt,Kt;function Qt(t){if(Kt)Kt();Kt=t&&t.S()}function Xt({data:t}){const n=Zt(t);n.value=t;const e=Mt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Xt.displayName="_st";Object.defineProperties(f.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Xt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});Gt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof f){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});Gt("__r",(t,n)=>{Qt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Jt=_;Qt(e);t(n)});Gt("__e",(t,n,e,_)=>{Qt();Jt=void 0;t(n,e,_)});Gt("diffed",(t,n)=>{Qt();Jt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Yt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Yt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}Gt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});Gt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Zt(t){return Mt(()=>c(t),[])}function tn(t){const n=Dt(t);n.current=t;Jt.__$f|=4;return Mt(()=>v(()=>n.current()),[])}function nn(t){const n=Dt(t);n.current=t;Nt(()=>k(()=>n.current()),[])}var en=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][s+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var rn=on.bind(O);export{q as Component,j as Fragment,f as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,O as createElement,I as createRef,k as effect,O as h,rn as html,ft as hydrate,C as isValidElement,S as options,st as render,c as signal,Y as toChildArray,o as untracked,Ft as useCallback,tn as useComputed,At as useContext,Vt as useDebugValue,Nt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,$t as useLayoutEffect,Mt as useMemo,Pt as useReducer,Dt as useRef,Zt as useSignal,nn as useSignalEffect,Ht as useState}; +const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function f(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function s(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}s.prototype.brand=t;s.prototype.h=function(){return!0};s.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};s.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};s.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};s.prototype.valueOf=function(){return this.value};s.prototype.toString=function(){return this.value+""};s.prototype.toJSON=function(){return this.value};s.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(s.prototype,"value",{get(){const t=f(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new s(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){s.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new s).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}s.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){s.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=f(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,T,D,M={},F=[],A=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,W=Array.isArray;function L(t,n){for(var e in n)t[e]=n[e];return t}function O(t){var n=t.parentNode;n&&n.removeChild(t)}function R(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return I(t,r,_,i,null)}function I(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function V(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,f,s){var c,h,a,p,d,v=_&&_.__k||F,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?I(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i)?(i.__=t,i.__b=t.__b+1,u=Z(i,e,r,s),i.__i=u,o=null,-1!==u&&(s--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u==r-1?c--:u==r+1?c++:u>r?s>l-r?c+=u-r:c--:u(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),I(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+D++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=[],(_={})[n]=this,this.getChildContext=function(){return _},this.componentWillUnmount=function(){e=null},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.some((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.push(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e&&e.splice(e.indexOf(t),1),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=F.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=L({},this.state),"function"==typeof t&&(t=t(L({},e),this.props)),t&&L(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),T=et(!0),D=0;var at,pt,dt,vt,yt=0,mt=[],gt=S,bt=gt.__b,kt=gt.__r,wt=gt.diffed,St=gt.__c,xt=gt.unmount,Ct=gt.__;function Ut(t,n){gt.__h&>.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({}),e.__[t]}function Et(t){return yt=1,Ht(Bt,t)}function Ht(t,n,e){var _=Ut(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):Bt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Pt(t,n){var e=Ut(at++,3);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function Nt(t,n){var e=Ut(at++,4);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function $t(t){return yt=5,Dt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,Nt((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Dt(t,n){var e=Ut(at++,7);return qt(e.__H,n)&&(e.__=t(),e.__H=n,e.__h=t),e.__}function Mt(t,n){return yt=8,Dt((function(){return t}),n)}function Ft(t){var n=pt.context[t.__c],e=Ut(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function At(t,n){gt.useDebugValue&>.useDebugValue(n?n(t):t)}function Wt(t){var n=Ut(at++,10),e=Et();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Ut(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(Vt),t.__H.__h.forEach(jt),t.__H.__h=[]}catch(n){t.__H.__h=[],gt.__e(n,t.__v)}}gt.__b=function(t){pt=null,bt&&bt(t)},gt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ct&&Ct(t,n)},gt.__r=function(t){kt&&kt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.i=t.__N=void 0}))):(n.__h.forEach(Vt),n.__h.forEach(jt),n.__h=[],at=0)),dt=pt},gt.diffed=function(t){wt&&wt(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===gt.requestAnimationFrame||((vt=gt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.i=void 0}))),dt=pt=null},gt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(Vt),t.__h=t.__h.filter((function(t){return!t.__||jt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],gt.__e(r,t.__v)}})),St&&St(t,n)},gt.unmount=function(t){xt&&xt(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{Vt(t)}catch(t){n=t}})),e.__H=void 0,n&>.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function Vt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function jt(t){var n=pt;t.__c=t.__(),pt=n}function qt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function Bt(t,n){return"function"==typeof n?n(t):n}function zt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Gt,Jt;function Kt(t){if(Jt)Jt();Jt=t&&t.S()}function Qt({data:t}){const n=Yt(t);n.value=t;const e=Dt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Qt.displayName="_st";Object.defineProperties(s.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Qt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});zt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof s){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});zt("__r",(t,n)=>{Kt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Gt=_;Kt(e);t(n)});zt("__e",(t,n,e,_)=>{Kt();Gt=void 0;t(n,e,_)});zt("diffed",(t,n)=>{Kt();Gt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Xt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Xt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}zt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});zt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Yt(t){return Dt(()=>c(t),[])}function Zt(t){const n=$t(t);n.current=t;Gt.__$f|=4;return Dt(()=>v(()=>n.current()),[])}function tn(t){const n=$t(t);n.current=t;Pt(()=>k(()=>n.current()),[])}var nn=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][f+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var on=_n.bind(R);export{q as Component,j as Fragment,s as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,R as createElement,V as createRef,k as effect,R as h,on as html,st as hydrate,C as isValidElement,S as options,ft as render,c as signal,Y as toChildArray,o as untracked,Mt as useCallback,Zt as useComputed,Ft as useContext,At as useDebugValue,Pt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,Nt as useLayoutEffect,Dt as useMemo,Ht as useReducer,$t as useRef,Yt as useSignal,tn as useSignalEffect,Et as useState}; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c25338f57..cc938e80d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -15,6 +15,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -67,7 +69,6 @@ enum slot_command { enum server_state { SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed }; enum server_task_type { @@ -693,8 +694,8 @@ struct server_context { n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_should_add_bos_token(model); - has_eos_token = llama_add_eos_token(model) != 1; + add_bos_token = llama_add_bos_token(model); + has_eos_token = !llama_add_eos_token(model); return true; } @@ -1322,7 +1323,7 @@ struct server_context { return json { {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, + {"n_predict", slot.n_predict}, // Server configured n_predict {"model", params.model_alias}, {"seed", slot.sparams.seed}, {"temperature", slot.sparams.temp}, @@ -1344,7 +1345,7 @@ struct server_context { {"mirostat_eta", slot.sparams.mirostat_eta}, {"penalize_nl", slot.sparams.penalize_nl}, {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict + {"max_tokens", slot.params.n_predict}, // User configured n_predict {"n_keep", slot.params.n_keep}, {"n_discard", slot.params.n_discard}, {"ignore_eos", ignore_eos}, @@ -1852,6 +1853,8 @@ struct server_context { llama_lora_adapters_apply(ctx, lora_adapters); server_task_result result; result.id = task.id; + result.stop = true; + result.error = false; result.data = json{{ "success", true }}; queue_results.send(result); } break; @@ -2036,7 +2039,7 @@ struct server_context { slot.t_start_generation = 0; if (slot.infill) { - const bool add_bos = llama_should_add_bos_token(model); + const bool add_bos = llama_add_bos_token(model); bool suff_rm_leading_spc = true; if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { params.input_suffix.erase(0, 1); @@ -2504,6 +2507,9 @@ int main(int argc, char ** argv) { return 1; } + // parse arguments from environment variables + gpt_params_parse_from_env(params); + // TODO: not great to use extern vars server_log_json = params.log_json; server_verbose = params.verbosity > 0; @@ -2528,8 +2534,8 @@ int main(int argc, char ** argv) { }); LOG_INFO("system info", { - {"n_threads", params.n_threads}, - {"n_threads_batch", params.n_threads_batch}, + {"n_threads", params.cpuparams.n_threads}, + {"n_threads_batch", params.cpuparams_batch.n_threads}, {"total_threads", std::thread::hardware_concurrency()}, {"system_info", llama_print_system_info()}, }); @@ -2554,19 +2560,19 @@ int main(int argc, char ** argv) { svr->set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { + // Access-Control-Allow-Origin is already set by middleware res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); - return res.set_content("", "application/json; charset=utf-8"); + return res.set_content("", "text/html"); // blank response, no data }); svr->set_logger(log_server_request); auto res_error = [](httplib::Response & res, json error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(), "application/json; charset=utf-8"); + res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; @@ -2596,11 +2602,6 @@ int main(int argc, char ** argv) { svr->set_read_timeout (params.timeout_read); svr->set_write_timeout(params.timeout_write); - if (!svr->bind_to_port(params.hostname, params.port)) { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); - return 1; - } - std::unordered_map log_data; log_data["hostname"] = params.hostname; @@ -2616,35 +2617,6 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - // load the model - if (!ctx_server.load_model(params)) { - state.store(SERVER_STATE_ERROR); - return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - } - - LOG_INFO("model loaded", {}); - - const auto model_meta = ctx_server.model_meta(); - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) { - if (!ctx_server.validate_model_chat_template()) { - LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - params.chat_template = "chatml"; - } - } - - // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", { - {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } - // // Middlewares // @@ -2688,8 +2660,6 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - // TODO: make another middleware for CORS related logic - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WARNING("Unauthorized: Invalid API Key", {}); @@ -2697,8 +2667,21 @@ int main(int argc, char ** argv) { return false; }; + auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + return false; + } + return true; + }; + // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } if (!middleware_validate_api_key(req, res)) { return httplib::Server::HandlerResponse::Handled; } @@ -2709,62 +2692,15 @@ int main(int argc, char ** argv) { // Route handlers (or controllers) // - const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - switch (current_state) { - case SERVER_STATE_READY: - { - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - task.id_target = -1; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - const int n_idle_slots = result.data.at("idle"); - const int n_processing_slots = result.data.at("processing"); - - json health = { - {"status", "ok"}, - {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots} - }; - - res.status = 200; // HTTP OK - if (params.endpoint_slots && req.has_param("include_slots")) { - health["slots"] = result.data.at("slots"); - } - - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable - } - } - - res.set_content(health.dump(), "application/json"); - break; - } - case SERVER_STATE_LOADING_MODEL: - { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } break; - case SERVER_STATE_ERROR: - { - res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); - } break; - } + const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + // error and loading states are handled by middleware + json health = {{"status", "ok"}}; + res.set_content(health.dump(), "application/json"); }; - const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { + const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2782,13 +2718,22 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); - res.set_content(result.data.at("slots").dump(), "application/json"); + // optionally return "fail_on_no_slot" error + const int n_idle_slots = result.data.at("idle"); + if (req.has_param("fail_on_no_slot")) { + if (n_idle_slots == 0) { + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return; + } + } + + res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2913,7 +2858,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2943,7 +2888,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2963,13 +2908,11 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - std::string id_slot_str = req.path_params.at("id_slot"); int id_slot; @@ -2993,7 +2936,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -3002,7 +2945,6 @@ int main(int argc, char ** argv) { curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); } } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, @@ -3010,7 +2952,7 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() } }; - res.set_content(data.dump(), "application/json; charset=utf-8"); + res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3019,8 +2961,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3031,7 +2971,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3094,9 +3034,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - + const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { @@ -3105,12 +3043,12 @@ int main(int argc, char ** argv) { {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta} + {"meta", ctx_server.model_meta()} }, }} }; - res.set_content(models.dump(), "application/json; charset=utf-8"); + res.set_content(models.dump(), MIMETYPE_JSON); }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3118,8 +3056,6 @@ int main(int argc, char ** argv) { res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3134,7 +3070,7 @@ int main(int argc, char ** argv) { if (!result.error && result.stop) { json result_oai = format_final_response_oaicompat(data, result.data, completion_id); - res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3196,8 +3132,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3208,7 +3142,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3256,7 +3190,6 @@ int main(int argc, char ** argv) { }; const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; @@ -3265,11 +3198,10 @@ int main(int argc, char ** argv) { tokens = ctx_server.tokenize(body.at("content"), add_special); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; @@ -3279,12 +3211,10 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); bool is_openai = false; @@ -3330,11 +3260,10 @@ int main(int argc, char ** argv) { json root = is_openai ? format_embeddings_response_oaicompat(body, responses) : responses[0]; - return res.set_content(root.dump(), "application/json; charset=utf-8"); + return res.set_content(root.dump(), MIMETYPE_JSON); }; - const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { auto & la = ctx_server.lora_adapters[i]; @@ -3344,13 +3273,11 @@ int main(int argc, char ** argv) { {"scale", la.scale}, }); } - res.set_content(result.dump(), "application/json"); + res.set_content(result.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const std::vector body = json::parse(req.body); int max_idx = ctx_server.lora_adapters.size(); @@ -3378,7 +3305,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; @@ -3454,35 +3381,75 @@ int main(int argc, char ** argv) { log_data["n_threads_http"] = std::to_string(params.n_threads_http); svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - LOG_INFO("HTTP server listening", log_data); + // clean up function, to be called before exit + auto clean_up = [&svr]() { + svr->stop(); + llama_backend_free(); + }; - // run the HTTP server in a thread - see comment below - std::thread t([&]() { - if (!svr->listen_after_bind()) { - state.store(SERVER_STATE_ERROR); - return 1; + // bind HTTP listen port, run the HTTP server in a thread + if (!svr->bind_to_port(params.hostname, params.port)) { + LOG_ERROR("couldn't bind HTTP server socket", { + {"hostname", params.hostname}, + {"port", params.port}, + }); + clean_up(); + LOG_ERROR("exiting due to HTTP server error", {}); + return 1; + } + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + + LOG_INFO("HTTP server is listening", log_data); + + // load the model + LOG_INFO("loading model", log_data); + if (!ctx_server.load_model(params)) { + clean_up(); + t.join(); + LOG_ERROR("exiting due to model loading error", {}); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INFO("model loaded", {}); + + // if a custom chat template is not supplied, we will use the one that comes with the model (if any) + if (params.chat_template.empty()) { + if (!ctx_server.validate_model_chat_template()) { + LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); + params.chat_template = "chatml"; + } } - return 0; - }); + // print sample chat example to make it clear which template is used + { + LOG_INFO("chat template", { + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, + }); + } - ctx_server.queue_tasks.on_new_task(std::bind( - &server_context::process_single_task, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); + ctx_server.queue_tasks.on_new_task(std::bind( + &server_context::process_single_task, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_update_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + ctx_server.queue_tasks.start_loop(); + } #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -3498,12 +3465,8 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - ctx_server.queue_tasks.start_loop(); - - svr->stop(); + clean_up(); t.join(); - llama_backend_free(); - return 0; } diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 6705a34fc..1ba7b60b6 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -205,27 +205,20 @@ def step_start_server(context): async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): match expecting_status: case 'healthy': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30) case 'ready' | 'idle': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30, - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=context.n_slots, - slots_processing=0, - expected_slots=[{'id': slot_id, 'state': 0} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30, + params={'fail_on_no_slot': 1}, + slots_idle=context.n_slots, + slots_processing=0) case 'busy': - await wait_for_health_status(context, context.base_url, 503, - 'no slot available', - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=0, - slots_processing=context.n_slots, - expected_slots=[{'id': slot_id, 'state': 1} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 503, + params={'fail_on_no_slot': 1}, + slots_idle=0, + slots_processing=context.n_slots) case _: assert False, "unknown status" @@ -1187,17 +1180,15 @@ async def gather_tasks_results(context): return n_completions -async def wait_for_health_status(context, - base_url, - expected_http_status_code, - expected_health_status, - timeout=3, - params=None, - slots_idle=None, - slots_processing=None, - expected_slots=None): +async def wait_for_slots_status(context, + base_url, + expected_http_status_code, + timeout=3, + params=None, + slots_idle=None, + slots_processing=None): if context.debug: - print(f"Starting checking for health for expected_health_status={expected_health_status}") + print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}") interval = 0.5 counter = 0 if 'GITHUB_ACTIONS' in os.environ: @@ -1205,26 +1196,19 @@ async def wait_for_health_status(context, async with aiohttp.ClientSession() as session: while True: - async with await session.get(f'{base_url}/health', params=params) as health_response: - status_code = health_response.status - health = await health_response.json() + async with await session.get(f'{base_url}/slots', params=params) as slots_response: + status_code = slots_response.status + slots = await slots_response.json() if context.debug: - print(f"HEALTH - response for expected health status='{expected_health_status}' on " - f"'{base_url}/health'?{params} is {health}\n") - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) - return - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) + print(f"slots responses {slots}\n") + if status_code == 503 and status_code == expected_http_status_code: return + if status_code == 200 and status_code == expected_http_status_code: + n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots) + n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots) + if ((slots_idle is None or slots_idle == n_slots_idle) + and (slots_processing is None or slots_processing == n_slots_processing)): + return await asyncio.sleep(interval) counter += interval @@ -1238,7 +1222,7 @@ async def wait_for_health_status(context, if n_completions > 0: return - assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}' + assert False, f'slots check timeout exceeded {counter}s>={timeout}' def assert_embeddings(embeddings): diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index b051a18f1..1616edecb 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -73,10 +73,11 @@ int main(int argc, char ** argv) { // load the draft model params.model = params.model_draft; params.n_gpu_layers = params.n_gpu_layers_draft; - if (params.n_threads_draft > 0) { - params.n_threads = params.n_threads_draft; + if (params.draft_cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.draft_cpuparams.n_threads; } - params.n_threads_batch = params.n_threads_batch_draft; + + params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; llama_init_result llama_init_dft = llama_init_from_gpt_params(params); model_dft = llama_init_dft.model; ctx_dft = llama_init_dft.context; diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index 17f5e4961..c817be566 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -362,7 +362,7 @@ int main(int raw_argc, char ** raw_argv) { prompt = stdin_buffer.str(); } - const bool model_wants_add_bos = llama_should_add_bos_token(model); + const bool model_wants_add_bos = llama_add_bos_token(model); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; diff --git a/flake.lock b/flake.lock index f9e1548a2..cc1ebe299 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1723175592, - "narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=", + "lastModified": 1724224976, + "narHash": "sha256-Z/ELQhrSd7bMzTO8r7NZgi9g5emh+aRKoCdaAv5fiO0=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5e0ca22929f3342b19569b21b2f3462f053e497b", + "rev": "c374d94f1536013ca8e92341b540eba4c22f9c62", "type": "github" }, "original": { diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 434c13b34..0dff47d65 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -7,8 +7,8 @@ extern "C" { #endif typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; -typedef struct ggml_backend_buffer * ggml_backend_buffer_t; -typedef struct ggml_backend * ggml_backend_t; +typedef struct ggml_backend_buffer * ggml_backend_buffer_t; +typedef struct ggml_backend * ggml_backend_t; // Tensor allocator struct ggml_tallocr { diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 5f3f1e286..e497b6d02 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -63,6 +63,7 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // "offset" refers to the offset of the tensor data for setting/getting data GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); @@ -102,6 +103,7 @@ extern "C" { GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); // Create a backend buffer from an existing pointer diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1d2a35402..5233a9995 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -220,7 +220,7 @@ #include #define GGML_FILE_MAGIC 0x67676d6c // "ggml" -#define GGML_FILE_VERSION 1 +#define GGML_FILE_VERSION 2 #define GGML_QNT_VERSION 2 // bump this on quantization format changes #define GGML_QNT_VERSION_FACTOR 1000 // do not change this @@ -231,6 +231,8 @@ #define GGML_MAX_SRC 10 #ifndef GGML_MAX_NAME #define GGML_MAX_NAME 64 +#define GGML_MAX_N_THREADS 512 + #endif #define GGML_MAX_OP_PARAMS 64 #define GGML_DEFAULT_N_THREADS 4 @@ -453,6 +455,8 @@ extern "C" { GGML_OP_SQR, GGML_OP_SQRT, GGML_OP_LOG, + GGML_OP_SIN, + GGML_OP_COS, GGML_OP_SUM, GGML_OP_SUM_ROWS, GGML_OP_MEAN, @@ -490,9 +494,11 @@ extern "C" { GGML_OP_CLAMP, GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, + GGML_OP_IM2COL_BACK, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, + GGML_OP_POOL_2D_BACK, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, GGML_OP_ARANGE, @@ -624,6 +630,29 @@ extern "C" { // If it returns true, the computation is aborted typedef bool (*ggml_abort_callback)(void * data); + // Scheduling priorities + enum ggml_sched_priority { + GGML_SCHED_PRIO_NORMAL, + GGML_SCHED_PRIO_MEDIUM, + GGML_SCHED_PRIO_HIGH, + GGML_SCHED_PRIO_REALTIME + }; + + // Threadpool params + // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults + struct ggml_threadpool_params { + bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct ggml_threadpool; // forward declaration, see ggml.c + + typedef struct ggml_threadpool * ggml_threadpool_t; + // the compute plan that needs to be prepared for ggml_graph_compute() // since https://github.com/ggerganov/ggml/issues/287 struct ggml_cplan { @@ -631,6 +660,7 @@ extern "C" { uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` int n_threads; + struct ggml_threadpool * threadpool; // abort ggml_graph_compute when true ggml_abort_callback abort_callback; @@ -969,6 +999,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // return scalar GGML_API struct ggml_tensor * ggml_sum( struct ggml_context * ctx, @@ -1566,34 +1612,49 @@ extern "C" { float min, float max); + // im2col + // converts data into a format that effectively results in a convolution when combined with matrix multiplication GGML_API struct ggml_tensor * ggml_im2col( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum ggml_type dst_type); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D, + enum ggml_type dst_type); + + GGML_API struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // gradient of im2col output + int64_t * ne, // shape of im2col input + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1, // dilation dimension 1 + bool is_2D); GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data int s0, // stride int p0, // padding int d0); // dilation @@ -1602,29 +1663,29 @@ extern "C" { // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) GGML_API struct ggml_tensor* ggml_conv_1d_ph( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s, - int d); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s, // stride + int d); // dilation GGML_API struct ggml_tensor * ggml_conv_transpose_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation GGML_API struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 // kernel size is a->ne[0] x a->ne[1] @@ -1686,6 +1747,18 @@ extern "C" { float p0, float p1); + GGML_API struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, // "a"/input used in forward pass + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + // nearest interpolate // multiplies ne0 and ne1 by scale factor // used in stable-diffusion @@ -1760,7 +1833,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias); + float max_bias, + float logit_softcap); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, @@ -1777,10 +1851,8 @@ extern "C" { GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq); + struct ggml_tensor * sx, + struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, @@ -1789,8 +1861,7 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq); + struct ggml_tensor * C); // partition into non-overlapping windows with padding if needed // example: @@ -2012,10 +2083,23 @@ extern "C" { GGML_API size_t ggml_graph_overhead(void); GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); + GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads); + GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params *p, int n_threads); + GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params *p0, const struct ggml_threadpool_params *p1); + GGML_API struct ggml_threadpool* ggml_threadpool_new (struct ggml_threadpool_params * params); + GGML_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); + GGML_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool); + GGML_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); + GGML_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); + // ggml_graph_plan() has to be called before ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data - GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); - GGML_API enum ggml_status ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + GGML_API struct ggml_cplan ggml_graph_plan( + const struct ggml_cgraph * cgraph, + int n_threads, /* = GGML_DEFAULT_N_THREADS */ + struct ggml_threadpool * threadpool /* = NULL */ ); + GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + // same as ggml_graph_compute() but the work data is allocated as a part of the context // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 1775ef3cc..ec7d30825 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -549,6 +549,13 @@ if (GGML_SYCL) file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp") list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") + find_package(DNNL) + message("-- DNNL found:" ${DNNL_FOUND}) + if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND}) + else() + add_compile_definitions(GGML_SYCL_DNNL=0) + endif() if (WIN32) find_package(IntelSYCL REQUIRED) find_package(MKL REQUIRED) @@ -561,6 +568,9 @@ if (GGML_SYCL) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl) endif() endif() + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_EXTRA_LIBS DNNL::dnnl) + endif() endif() if (GGML_RPC) @@ -1237,7 +1247,7 @@ endif() # Data types, macros and functions related to controlling CPU affinity and # some memory allocation are available on Linux through GNU extensions in libc -if (CMAKE_SYSTEM_NAME MATCHES "Linux") +if (CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Android") add_compile_definitions(_GNU_SOURCE) endif() diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 7adaadc92..332578fd4 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -337,33 +337,18 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds } size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); } size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); } size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - if (!quant_weights) { - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); - } - else { - assert(false); - return 0; - } + UNUSED(quant_weights); + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); } void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index e1651cc64..5b877db35 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -722,9 +722,11 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { #endif struct ggml_backend_cpu_context { - int n_threads; - void * work_data; - size_t work_size; + int n_threads; + ggml_threadpool_t threadpool; + + void * work_data; + size_t work_size; ggml_abort_callback abort_callback; void * abort_callback_data; @@ -759,7 +761,7 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); - cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); cpu_plan->cgraph = *cgraph; // FIXME: deep copy if (cpu_plan->cplan.work_size > 0) { @@ -796,7 +798,7 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backe GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); if (cpu_ctx->work_size < cplan.work_size) { free(cpu_ctx->work_data); @@ -873,6 +875,7 @@ ggml_backend_t ggml_backend_cpu_init(void) { } ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->threadpool = NULL; ctx->work_data = NULL; ctx->work_size = 0; ctx->abort_callback = NULL; @@ -903,6 +906,18 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { ctx->n_threads = n_threads; } +void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + + if (ctx->threadpool && ctx->threadpool != threadpool) { + // already had a different threadpool, pause/suspend it before switching + ggml_threadpool_pause(ctx->threadpool); + } + ctx->threadpool = threadpool; +} + void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); @@ -1018,10 +1033,6 @@ static bool ggml_is_view_op(enum ggml_op op) { #define GGML_SCHED_MAX_BACKENDS 16 #endif -#ifndef GGML_SCHED_MAX_SPLITS -#define GGML_SCHED_MAX_SPLITS 2048 -#endif - #ifndef GGML_SCHED_MAX_SPLIT_INPUTS #define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC #endif @@ -1125,7 +1136,8 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co } #if 0 -static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only +#define GGML_SCHED_MAX_SPLITS_DEBUG 4096 +static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) #define GET_CAUSE(node) causes[hash_id(node)] #else @@ -1549,7 +1561,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); } - GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS); split = &sched->splits[i_split]; split->backend_id = node_backend_id; split->i_start = i; @@ -1865,13 +1876,14 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); - const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; + const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph + const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); - sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); sched->context_buffer = malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 682c30d45..8a844b02a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -9,8 +9,10 @@ #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/convert.cuh" #include "ggml-cuda/cpy.cuh" +#include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" @@ -29,7 +31,6 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/conv-transpose-1d.cuh" #include #include @@ -2181,6 +2182,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD: ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_SUB: + ggml_cuda_op_sub(ctx, dst); + break; case GGML_OP_ACC: ggml_cuda_op_acc(ctx, dst); break; @@ -2267,6 +2271,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SQRT: ggml_cuda_op_sqrt(ctx, dst); break; + case GGML_OP_SIN: + ggml_cuda_op_sin(ctx, dst); + break; + case GGML_OP_COS: + ggml_cuda_op_cos(ctx, dst); + break; case GGML_OP_CLAMP: ggml_cuda_op_clamp(ctx, dst); break; @@ -2303,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; + case GGML_OP_CROSS_ENTROPY_LOSS: + ggml_cuda_cross_entropy_loss(ctx, dst); + break; default: return false; } @@ -2610,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); } } @@ -2853,12 +2867,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: @@ -2890,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + case GGML_OP_CROSS_ENTROPY_LOSS: + return true; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 34bc67acd..e1390a041 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) { return a + b; } +static __device__ __forceinline__ float op_sub(const float a, const float b) { + return a - b; +} + static __device__ __forceinline__ float op_mul(const float a, const float b) { return a * b; } @@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 4f63d6372..198c9ef6f 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -2,5 +2,6 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu new file mode 100644 index 000000000..a14043e70 --- /dev/null +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -0,0 +1,106 @@ +#include "common.cuh" +#include "cross-entropy-loss.cuh" +#include "sumrows.cuh" + +#include +#include + +static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; + + const int ne_tmp = WARP_SIZE*nclasses; + + extern __shared__ float tmp_all[]; + float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; + float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; + + // Each warp first loads ne_tmp logits/labels into shared memory: + for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { + const int ig = i0*nclasses + i; // ig == i global + + tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; + tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; + } + + // Each thread in the warp then calculates the cross entropy loss for a single row. + // TODO: pad in order to avoid shared memory bank conflicts. + + // Find maximum for softmax: + float max = -INFINITY; + for (int i = 0; i < nclasses; ++i) { + max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + } + + // Calculate log(softmax(logits)) which is just logits - max: + float sum = 0.0f; + for (int i = 0; i < nclasses; ++i) { + float val = tmp_logits[lane_id*nclasses + i] - max; + sum += expf(val); + tmp_logits[lane_id*nclasses + i] = val; + } + sum = logf(sum); + + // log(exp(logits - max) / sum) = (logits - max) - log(sum) + float loss = 0.0f; + for (int i = 0; i < nclasses; ++i) { + loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + } + loss = -warp_reduce_sum(loss) / (float)k; + + __syncthreads(); + + if (lane_id == 0) { + tmp_all[warp_id] = loss; + } + + __syncthreads(); + + if (warp_id != 0) { + return; + } + + loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; + loss = warp_reduce_sum(loss); + + if (lane_id != 0) { + return; + } + + dst[blockIdx.x] = loss; +} + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); + const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + + ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); + + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + + // Combine results from individual blocks: + sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream); +} diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh new file mode 100644 index 000000000..9d7b8b0f0 --- /dev/null +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 950fd93df..1fb5c09c3 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -657,11 +658,17 @@ void launch_fattn( const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = Q->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -675,7 +682,7 @@ void launch_fattn( V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 1b2fd500b..342f2eb66 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; - half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + half sum; + if (use_logit_softcap) { + const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + sum = logit_softcap * tanhf(tmp.x + tmp.y); + } else { + sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); @@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f3e68dbfa..827437ca0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; + if (use_logit_softcap) { + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); @@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32( } } -template +template void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 02a4ad072..448a9a905 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); @@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * KQV = dst; - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 11a5e355f..bf5125902 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); @@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32( } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index ae2322242..b10d19d93 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -6,7 +6,7 @@ #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. @@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16( const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } } float KQ_max_new = KQ_max_f[j0/nwarps]; @@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; @@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; @@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 29f608b0f..f87f33b3e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; if (precision != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 82e8e875f..38dbf1b5e 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } -static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); @@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); - const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index e7545f83c..191db1c13 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -1,3 +1,5 @@ #include "common.cuh" +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); + void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9e208011..89abfc21d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) { dst[i] = sqrtf(x[i]); } +static __global__ void sin_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = sinf(x[i]); +} + +static __global__ void cos_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = cosf(x[i]); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_ sqrt_f32<<>>(x, dst, k); } +static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE; + sin_f32<<>>(x, dst, k); +} + +static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE; + cos_f32<<>>(x, dst, k); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 4cfb0479e..c610e996a 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -9,6 +9,8 @@ #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_SQRT_BLOCK_SIZE 256 +#define CUDA_SIN_BLOCK_SIZE 256 +#define CUDA_COS_BLOCK_SIZE 256 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 995f1934b..91b5e61b2 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -31,6 +31,8 @@ struct ggml_metal_kernel { enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_SUB, + GGML_METAL_KERNEL_TYPE_SUB_ROW, GGML_METAL_KERNEL_TYPE_MUL, GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, @@ -82,6 +84,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -205,6 +209,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_COUNT @@ -491,6 +498,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); @@ -542,6 +551,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -665,6 +676,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } @@ -765,15 +779,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_PERMUTE: case GGML_OP_CONCAT: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: - case GGML_OP_SQR: - case GGML_OP_SUM_ROWS: return true; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: case GGML_OP_GROUP_NORM: @@ -803,6 +822,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx return false; } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && @@ -1050,6 +1072,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: { @@ -1073,6 +1096,7 @@ static enum ggml_status ggml_metal_graph_compute( nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; default: GGML_ABORT("fatal error"); @@ -1082,6 +1106,7 @@ static enum ggml_status ggml_metal_graph_compute( } else { switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; default: GGML_ABORT("fatal error"); @@ -1409,6 +1434,48 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQRT: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SIN: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_COS: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SUM_ROWS: @@ -1538,6 +1605,121 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } } break; + case GGML_OP_SSM_CONV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SSM_SCAN: + { + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + struct ggml_tensor * src4 = gf->nodes[i]->src[4]; + struct ggml_tensor * src5 = gf->nodes[i]->src[5]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_MUL_MAT: { GGML_ASSERT(ne00 == ne10); @@ -2624,9 +2806,14 @@ static enum ggml_status ggml_metal_graph_compute( float scale; float max_bias; + float logit_softcap; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -2677,30 +2864,31 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3bb37d32a..f323ab5f4 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -17,7 +17,7 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, multiplication and division of two tensors +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( @@ -70,6 +70,56 @@ kernel void kernel_add( } } +kernel void kernel_sub( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + } +} + kernel void kernel_mul( device const char * src0, device const char * src1, @@ -226,6 +276,15 @@ kernel void kernel_add_row( dst[tpig] = src0[tpig] + src1[tpig % nb]; } +kernel void kernel_sub_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, @@ -358,6 +417,27 @@ kernel void kernel_sqr( dst[tpig] = src0[tpig] * src0[tpig]; } +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, @@ -667,6 +747,127 @@ kernel void kernel_diag_mask_inf_8( } } +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +// TODO: optimize +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = ne10; + const int64_t ncs = ne00; + const int64_t nr = ne01; + const int64_t n_t = ne1; + const int64_t n_s = ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); + device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// TODO: optimize +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); + device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + kernel void kernel_norm( device const void * src0, device float * dst, @@ -1976,6 +2177,7 @@ typedef void (flash_attn_ext_f16_t)( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2014,6 +2216,7 @@ kernel void kernel_flash_attn_ext_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2138,19 +2341,6 @@ kernel void kernel_flash_attn_ext_f16( } simdgroup_store(mqk, ss + 8*cc, TF, 0, false); - - const short tx = tiisg%4; - const short ty = tiisg/4; - - if (mask != q) { - // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } else { - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; - } } } @@ -2162,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16( float ms[Q]; for (short j = 0; j < Q; ++j) { - const short p = tiisg; - const float m = M[j]; - const float s = ss[j*TF + p]; + + // scale and apply the logitcap / mask + float s = ss[j*TF + tiisg]*scale; + + if (logit_softcap != 0.0f) { + s = logit_softcap*precise::tanh(s); + } + + if (mask != q) { + // mqk = mqk + mask*slope + s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; + } smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); @@ -2176,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; + ss[j*TF + tiisg] = vs; } // create a QxQ diagonal matrix for rescaling the output @@ -2345,6 +2544,7 @@ kernel void kernel_flash_attn_ext_vec_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2479,7 +2679,13 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask*slope if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; ss4[cc] = mqk; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d5b91c2db..48b90f01b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q8_K_ref(x, y, k); } -//===================================== Dot ptoducts ================================= +//===================================== Dot products ================================= // // Helper functions diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index 7757615f5..8f9d0a460 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of // RPC commands enum rpc_cmd { - ALLOC_BUFFER = 0, - GET_ALIGNMENT, - GET_MAX_SIZE, - BUFFER_GET_BASE, - FREE_BUFFER, - BUFFER_CLEAR, - SET_TENSOR, - GET_TENSOR, - COPY_TENSOR, - GRAPH_COMPUTE, - GET_DEVICE_MEMORY, + RPC_CMD_ALLOC_BUFFER = 0, + RPC_CMD_GET_ALIGNMENT, + RPC_CMD_GET_MAX_SIZE, + RPC_CMD_BUFFER_GET_BASE, + RPC_CMD_FREE_BUFFER, + RPC_CMD_BUFFER_CLEAR, + RPC_CMD_SET_TENSOR, + RPC_CMD_GET_TENSOR, + RPC_CMD_COPY_TENSOR, + RPC_CMD_GRAPH_COMPUTE, + RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_COUNT, }; // RPC data structures @@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.empty()); delete ctx; @@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | base_ptr (8 bytes) | @@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); std::vector output; - bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output); GGML_ASSERT(status); } @@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == size); // output serialization format: | data (size bytes) | @@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b memcpy(input.data(), &rpc_src, sizeof(rpc_src)); memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output); GGML_ASSERT(status); // output serialization format: | result (1 byte) | GGML_ASSERT(output.size() == 1); @@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output); GGML_ASSERT(status); } @@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer memcpy(input.data(), &size, sizeof(size)); std::vector output; auto sock = get_socket(buft_ctx->endpoint); - bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | @@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | alignment (8 bytes) | @@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | max_size (8 bytes) | @@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t serialize_graph(cgraph, input); std::vector output; auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 1); return (enum ggml_status)output[0]; @@ -636,7 +637,7 @@ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const } GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { + if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { return false; } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; @@ -678,6 +679,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const } auto sock = get_socket(endpoint); if (sock == nullptr) { + fprintf(stderr, "Failed to connect to %s\n", endpoint); return nullptr; } size_t alignment = get_alignment(sock); @@ -719,7 +721,7 @@ static void get_device_memory(const std::shared_ptr & sock, size_t * f // input serialization format: | 0 bytes | std::vector input; std::vector output; - bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | free (8 bytes) | total (8 bytes) | @@ -1098,59 +1100,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre if (!recv_data(sockfd, &cmd, 1)) { break; } + if (cmd >= RPC_CMD_COUNT) { + // fail fast if the command is invalid + fprintf(stderr, "Unknown command: %d\n", cmd); + break; + } std::vector input; std::vector output; uint64_t input_size; if (!recv_data(sockfd, &input_size, sizeof(input_size))) { break; } - input.resize(input_size); + try { + input.resize(input_size); + } catch (const std::bad_alloc & e) { + fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size); + break; + } if (!recv_data(sockfd, input.data(), input_size)) { break; } bool ok = true; switch (cmd) { - case ALLOC_BUFFER: { + case RPC_CMD_ALLOC_BUFFER: { ok = server.alloc_buffer(input, output); break; } - case GET_ALIGNMENT: { + case RPC_CMD_GET_ALIGNMENT: { server.get_alignment(output); break; } - case GET_MAX_SIZE: { + case RPC_CMD_GET_MAX_SIZE: { server.get_max_size(output); break; } - case BUFFER_GET_BASE: { + case RPC_CMD_BUFFER_GET_BASE: { ok = server.buffer_get_base(input, output); break; } - case FREE_BUFFER: { + case RPC_CMD_FREE_BUFFER: { ok = server.free_buffer(input); break; } - case BUFFER_CLEAR: { + case RPC_CMD_BUFFER_CLEAR: { ok = server.buffer_clear(input); break; } - case SET_TENSOR: { + case RPC_CMD_SET_TENSOR: { ok = server.set_tensor(input); break; } - case GET_TENSOR: { + case RPC_CMD_GET_TENSOR: { ok = server.get_tensor(input, output); break; } - case COPY_TENSOR: { + case RPC_CMD_COPY_TENSOR: { ok = server.copy_tensor(input, output); break; } - case GRAPH_COMPUTE: { + case RPC_CMD_GRAPH_COMPUTE: { ok = server.graph_compute(input, output); break; } - case GET_DEVICE_MEMORY: { + case RPC_CMD_GET_DEVICE_MEMORY: { // output serialization format: | free (8 bytes) | total (8 bytes) | output.resize(2*sizeof(uint64_t), 0); memcpy(output.data(), &free_mem, sizeof(free_mem)); @@ -1203,8 +1215,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free return; } printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + fflush(stdout); rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); + fflush(stdout); } #ifdef _WIN32 WSACleanup(); diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index d8eb86c2c..0d884f89a 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -38,6 +38,7 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/presets.hpp" +#include "ggml-sycl/gemm.hpp" bool ggml_sycl_loaded(void); void ggml_sycl_free_data(struct ggml_tensor * tensor); @@ -893,43 +894,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -template -static void im2col_kernel(const float *x, T *dst, int offset_delta, - int IW, int IH, int OW, int KW, int KH, - int pelements, int CHW, int s0, int s1, int p0, - int p1, int d0, int d1, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (i >= pelements) { - return; - } - - const int ksize = OW * (KH > 1 ? KW : 1); - const int kx = i / ksize; - const int kd = kx * ksize; - const int ky = (i - kd) / OW; - const int ix = i % OW; - - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1; - - const int64_t offset_dst = - (item_ct1.get_group(1) * OW + ix) * CHW + - (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = - sycl::vec(0.0f) - .convert()[0]; - } else { - const int64_t offset_src = item_ct1.get_group(0) * offset_delta; - dst[offset_dst] = - sycl::vec(x[offset_src + iih * IW + iiw]) - .convert()[0]; - } -} - template static void pool2d_nchw_kernel( const int ih, const int iw, const int oh, const int ow, @@ -1742,32 +1706,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, }); } -template -static void im2col_sycl(const float *x, T *dst, int IW, int IH, - int OW, int OH, int KW, int KH, int IC, - int offset_delta, int s0, int s1, int p0, - int p1, int d0, int d1, - queue_ptr stream) { - const int parallel_elements = OW * KW * KH; - const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; - sycl::range<3> block_nums(IC, OH, num_blocks); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * - sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH, - parallel_elements, (IC * KH * KW), s0, s1, p0, - p1, d0, d1, item_ct1); - }); - } -} - - static bool g_sycl_loaded = false; bool ggml_sycl_loaded(void) { @@ -2545,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, @@ -2554,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl( dpct::library_data_t::real_half))); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); +#else + auto dnnl_stream = ctx.stream_dnnl(stream); + DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), + src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#endif } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); @@ -2576,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; - +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +#else + auto dnnl_stream = ctx.stream_dnnl(stream); + DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt(), + src0_ddf_i, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt()); +#endif } (void) dst; (void) src1_ddq_i; @@ -2636,47 +2587,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens (void) src1_dd; } -inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; - - const int64_t IC = src1->ne[is_2D ? 2 : 1]; - const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; - - const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; - - const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; - - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - - if (dst->type == GGML_TYPE_F16) { - im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); - } else { - im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); - } - - (void) src0; - (void) src0_dd; -} - inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -3581,7 +3491,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE + && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE); bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 58dd9c9a6..d21b5f8dd 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -25,5 +25,6 @@ #include "norm.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "im2col.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index e878f4f50..cf5291b31 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try { << ", line:" << __LINE__ << std::endl; std::exit(1); } + +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { + const int64_t max_range = std::numeric_limits::max(); + int64_t sycl_down_blk_size = block_size; + int64_t global_range = accumulate_block_num * sycl_down_blk_size; + while(global_range > max_range) { + sycl_down_blk_size /= 2; + global_range = accumulate_block_num * sycl_down_blk_size; + } + return sycl_down_blk_size; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 86d8b40e8..05947ccb7 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,6 +19,10 @@ #include "dpct/helper.hpp" #include "ggml-sycl.h" #include "presets.hpp" +#if GGML_SYCL_DNNL +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" +#endif #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL @@ -130,6 +134,7 @@ typedef sycl::float2 dfloat2; #endif // GGML_SYCL_F16 #define MMVQ_MAX_BATCH_SIZE 8 +#define MMVQ_MIN_BATCH_SIZE 4 static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -276,6 +281,52 @@ struct ggml_backend_sycl_context { return stream(device, 0); } +#if GGML_SYCL_DNNL + dnnl::engine make_engine(sycl::queue* q) { + // Get the device associated with the queue + sycl::device dev = q->get_device(); + // Get the context associated with the queue + sycl::context ctx = q->get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + return eng; + } + + std::unordered_map stream_map; + std::unordered_map engine_map; + dnnl::stream stream_dnnl(int device, int _stream) { + auto q = stream(device, _stream); + return stream_dnnl(q); + } + dnnl::engine engine_dnnl(sycl::queue* qptr) { + auto it = engine_map.find(qptr); + if (it == engine_map.end()) { + auto eng = make_engine(qptr); + engine_map[qptr] = eng; + return eng; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl(sycl::queue* qptr) { + auto it = stream_map.find(qptr); + if (it == stream_map.end()) { + auto eng = engine_dnnl(qptr); + auto stream = dnnl::sycl_interop::make_stream(eng, *qptr); + stream_map[qptr] = stream; + return stream; + } + else + { + return it->second; + } + } + dnnl::stream stream_dnnl() { + return stream_dnnl(device, 0); + } +#endif + // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; @@ -352,4 +403,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { return acc.template get_multi_ptr().get(); } +int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 39c28753c..5fd15e6cd 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -3,19 +3,19 @@ #include "presets.hpp" template -static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, +static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { - const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) + + const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)); if (i >= k) { return; } - const int ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int64_t ib = i/qk; // block index + const int64_t iqs = (i%qk)/qr; // quant index + const int64_t iybs = i - i%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; @@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ template static void dequantize_block_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int k, + dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); + const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE); { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx, } template -static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb32 = k / 32; - const int nb = (k + 255) / 256; + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb32 = k / 32; - const int nb = (k + 255) / 256; + const int64_t nb32 = k / 32; + const int64_t nb = (k + 255) / 256; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k, template -static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 { dpct::has_capability_or_fail(stream->get_device(), @@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k, template -static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = k / QK_K; + const int64_t nb = k / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = (k + QK_K - 1) / QK_K; + const int64_t nb = (k + QK_K - 1) / QK_K; #if QK_K == 64 dequantize_row_iq4_nl_sycl(vx, y, k, stream); #else @@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k, } template -static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k, +static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { - const int nb = (k + QK_K - 1) / QK_K; + const int64_t nb = (k + QK_K - 1) / QK_K; { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); @@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k, } template -static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, +static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + // make each work-item deal with more elements since sycl global range can not exceed max int const src_t * x = (src_t *) vx; - - y[i] = x[i]; + for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) { + y[i] = x[i]; + } } template static void convert_unary_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int k, + dst_t *__restrict__ y, const int64_t k, dpct::queue_ptr stream) { - const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + + // decrease global range when it exceeds the max int + int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE); + sycl::range<3> block_nums(1, 1, num_blocks); + sycl::range<3> local_range(1, 1, local_size); { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); stream->parallel_for( - sycl::nd_range<3>( - sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { convert_unary(vx, y, k, item_ct1); }); diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index b1f10d635..0ce2874aa 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -17,7 +17,7 @@ template using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, - int k, dpct::queue_ptr stream); + int64_t k, dpct::queue_ptr stream); typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index ed8ad098b..8f4041fff 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -15,9 +15,9 @@ #include "common.hpp" -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); -static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_1 * x = (const block_q4_1 *) vx; @@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_1 * x = (const block_q5_1 *) vx; @@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, +static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, } template -static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri } template -static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, +static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -203,14 +203,14 @@ template static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_q2_K * x = (const block_q2_K *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; dst_t * y = yy + i*QK_K + 128*n; @@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); #else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); dst_t * y = yy + i*QK_K + 16*is + il; @@ -239,19 +239,19 @@ template static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_q3_K * x = (const block_q3_K *) vx; #if QK_K == 256 - const int r = item_ct1.get_local_id(2) / 4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); - const int n = tid / 4; - const int j = tid - 4*n; + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; + int64_t is = 8*n + 2*j + is0; int shift = 2*j; int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : @@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); #else - const int tid = item_ct1.get_local_id(2); - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const int64_t im = il/8; // 0...1 + const int64_t in = il%8; // 0...7 dst_t * y = yy + i*QK_K + 16*is + il; @@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { const block_q4_K * x = (const block_q4_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t is = 2*il; + const int64_t n = 4; dst_t * y = yy + i*QK_K + 64*il + n*ir; @@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri y[l +32] = d2 * (q_vec[l] >> 4) - m2; } #else - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); const uint8_t * q = x[i].qs; dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; @@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri const sycl::nd_item<3> &item_ct1) { const block_q5_K * x = (const block_q5_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = item_ct1.get_local_id(2); - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 dst_t * y = yy + i*QK_K + 64*il + 2*ir; @@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; #else - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 + const int64_t im = tid/8; // 0...3 + const int64_t in = tid%8; // 0...7 + const int64_t is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; dst_t * y = yy + i*QK_K + tid; @@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri const sycl::nd_item<3> &item_ct1) { const block_q6_K * x = (const block_q6_K *) vx; - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = item_ct1.get_local_id(2); - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; dst_t * y = yy + i*QK_K + 128*ip + il; @@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #else // assume 32 threads - const int tid = item_ct1.get_local_id(2); - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 dst_t * y = yy + i*QK_K + 16*ip + il; @@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res const uint8_t *ksigns_iq2xs_ptr, const uint8_t *kmask_iq2xs_ptr) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_xxs * x = (const block_iq2_xxs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * aux8 = (const uint8_t *)q2; @@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest const uint8_t *ksigns_iq2xs, const uint8_t *kmask_iq2xs) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_xs * x = (const block_iq2_xs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); @@ -504,13 +504,13 @@ __dpct_inline__ static void dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq2_s * x = (const block_iq2_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; @@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res const uint8_t *ksigns_iq2xs, const uint8_t *kmask_iq2xs) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq3_xxs * x = (const block_iq3_xxs *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * q3 = x[i].qs + 8*ib; const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; @@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq3_s * x = (const block_iq3_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); @@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint32_t *iq1s_grid_gpu) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); @@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1, const uint32_t *iq1s_grid_gpu) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq1_m * x = (const block_iq1_m *) vx; - const int tid = item_ct1.get_local_id(2); + const int64_t tid = item_ct1.get_local_id(2); #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * sc = (const uint16_t *)x[i].scales; iq1m_scale_t scale; @@ -656,12 +656,12 @@ __dpct_inline__ static void dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[ib].qs + 4*il; const float d = (float)x[ib].d; @@ -678,12 +678,12 @@ template __dpct_inline__ static void dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_group(2); + const int64_t i = item_ct1.get_group(2); const block_iq4_xs * x = (const block_iq4_xs *)vx; - const int tid = item_ct1.get_local_id(2); - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index ae45630e1..5c343822f 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -4,7 +4,7 @@ #include "presets.hpp" -static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; // automatic half -> float type cast if dfloat == float @@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v.y() = x[ib + iqs + 1]; } -static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; // automatic half -> float type cast if dfloat == float diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp new file mode 100644 index 000000000..2ad9b36f4 --- /dev/null +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -0,0 +1,101 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GEMM_HPP +#define GGML_SYCL_GEMM_HPP + +#include +#include + +#include "ggml-sycl.h" + +#if GGML_SYCL_DNNL + +#include "dnnl.hpp" +#include "dnnl_sycl.hpp" + +class DnnlGemmWrapper { +public: + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + + template + static constexpr dt to_dt() { + if constexpr (std::is_same_v) return dt::f32; + else if constexpr (std::is_same_v) return dt::f16; + else static_assert(0); + } + + static inline void row_gemm(sycl::queue& q, bool a_trans, + bool b_trans, int m, int n, int k, + const void* a, dt at, const void* b, dt bt, void* c, dt ct) + { + // Get the device associated with the queue + sycl::device dev = q.get_device(); + // Get the context associated with the queue + sycl::context ctx = q.get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); + dnnl::memory::dims a_dims = { m, k }; + dnnl::memory::dims b_dims = { k, n }; + dnnl::memory::dims c_dims = { m, n }; + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + // Create the primitive. + auto matmul_prim = dnnl::matmul(matmul_pd); + // Primitive arguments. + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + + matmul_prim.execute(stream, matmul_args); + } + + + static inline void row_gemm(const dnnl::stream& stream, bool a_trans, + bool b_trans, int m, int n, int k, + const void* a, dt at, const void* b, dt bt, void* c, dt ct) + { + auto const eng = stream.get_engine(); + dnnl::memory::dims a_dims = { m, k }; + dnnl::memory::dims b_dims = { k, n }; + dnnl::memory::dims c_dims = { m, n }; + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + // Create the primitive. + auto matmul_prim = dnnl::matmul(matmul_pd); + // Primitive arguments. + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + + matmul_prim.execute(stream, matmul_args); + } +}; + +#endif + +#endif // GGML_SYCL_GEMM_HPP diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp new file mode 100644 index 000000000..6a0a0fcd0 --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -0,0 +1,125 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "im2col.hpp" + +template +static void im2col_kernel( + const float *x, T *dst, int64_t batch_offset, int64_t offset_delta, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1, + const sycl::nd_item<3> &item_ct1) { + const int64_t work_group_size = item_ct1.get_local_range(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + + // make each work-item deal with more elements since sycl global range can not exceed max int + for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) { + + const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; + + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; + + const int64_t iiw = ix * s0 + kx * d0 - p0; + const int64_t iih = oh * s1 + ky * d1 - p1; + + const int64_t offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = + sycl::vec(0.0f) + .convert()[0]; + } else { + const int64_t offset_src = ic * offset_delta + batch * batch_offset; + dst[offset_dst] = + sycl::vec(x[offset_src + iih * IW + iiw]) + .convert()[0]; + } + } +} + +template +static void im2col_sycl( + const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, + queue_ptr stream) { + const int64_t parallel_elements = OW * KW * KH; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + + // decrease global range when it exceeds the max int + int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + sycl::range<3> block_nums(batch * IC, OH, num_blocks); + sycl::range<3> local_range(1, 1, local_size); + + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * local_range, local_range), + [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, + parallel_elements, (IC * KH * KW), s0, s1, p0, + p1, d0, d1, item_ct1); + }); + } +} + +void ggml_sycl_op_im2col( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[3]; + const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + + if (dst->type == GGML_TYPE_F16) { + im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } else { + im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } + + (void) src0; + (void) src0_dd; +} diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp new file mode 100644 index 000000000..7db144fbb --- /dev/null +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -0,0 +1,23 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_IM2COL_HPP +#define GGML_SYCL_IM2COL_HPP + +#include "common.hpp" + +void ggml_sycl_op_im2col( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream); + +#endif // GGML_SYCL_IM2COL_HPP diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 7a0ec706f..ca4f44cf7 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -180,6 +180,7 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16; vk_pipeline pipeline_mul_f32; vk_pipeline pipeline_div_f32; @@ -187,6 +188,8 @@ struct vk_device_struct { vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32; @@ -1687,6 +1690,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -1699,6 +1704,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3971,6 +3978,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; case GGML_OP_ADD: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_add_f32; @@ -4015,6 +4027,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sqr_f32; } return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -4163,6 +4185,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: @@ -4373,6 +4397,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_MUL: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: @@ -4463,6 +4489,28 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, }, dryrun); } +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + d_offset, + 0.0f, 0.0f, offset, + }, dryrun); +} + static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -4568,6 +4616,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + }); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + }); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); @@ -5621,12 +5695,15 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CPY: @@ -5668,6 +5745,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); @@ -5700,6 +5781,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SQR: ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); @@ -5808,6 +5897,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * switch (tensor->op) { case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_GET_ROWS: case GGML_OP_MUL: case GGML_OP_DIV: @@ -5815,6 +5905,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CPY: @@ -6539,12 +6631,15 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_CONT: @@ -6987,6 +7082,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_CLAMP) { tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_PAD) { @@ -6995,6 +7094,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 88e4fb732..dc6cdca0b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -69,23 +69,42 @@ int ggml_sve_cnt_b = 0; #endif #include +#if !defined(__clang__) typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; typedef atomic_int atomic_flag; #define ATOMIC_FLAG_INIT 0 +typedef enum { + memory_order_relaxed, + memory_order_consume, + memory_order_acquire, + memory_order_release, + memory_order_acq_rel, + memory_order_seq_cst +} memory_order; + static void atomic_store(atomic_int * ptr, LONG val) { InterlockedExchange(ptr, val); } +static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { + // TODO: add support for explicit memory order + InterlockedExchange(ptr, val); +} static LONG atomic_load(atomic_int * ptr) { return InterlockedCompareExchange(ptr, 0, 0); } +static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedCompareExchange(ptr, 0, 0); +} static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { return InterlockedExchangeAdd(ptr, inc); } -static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); +static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedExchangeAdd(ptr, inc); } static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { return InterlockedExchange(ptr, 1); @@ -93,6 +112,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { static void atomic_flag_clear(atomic_flag * ptr) { InterlockedExchange(ptr, 0); } +#else // clang +#include +#endif typedef HANDLE pthread_t; @@ -121,8 +143,10 @@ static int sched_yield (void) { return 0; } #else + #include #include +#include typedef void * thread_ret_t; @@ -1868,28 +1892,102 @@ struct ggml_context_container { struct ggml_context context; }; -struct ggml_compute_state_shared { - const struct ggml_cgraph * cgraph; - const struct ggml_cplan * cplan; +// +// Threading defs +// - int n_threads; +typedef pthread_t ggml_thread_t; + +#if defined(_WIN32) + +typedef CONDITION_VARIABLE ggml_cond_t; +typedef SRWLOCK ggml_mutex_t; + +#define ggml_mutex_init(m) InitializeSRWLock(m) +#define ggml_mutex_destroy(m) +#define ggml_mutex_lock(m) AcquireSRWLockExclusive(m) +#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) +#define ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) +#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) + +#define ggml_cond_init(c) InitializeConditionVariable(c) +#define ggml_cond_destroy(c) +#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) +#define ggml_cond_broadcast(c) WakeAllConditionVariable(c) + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +typedef pthread_cond_t ggml_cond_t; +typedef pthread_mutex_t ggml_mutex_t; + +#define ggml_mutex_init(m) pthread_mutex_init(m, NULL) +#define ggml_mutex_destroy(m) pthread_mutex_destroy(m) +#define ggml_mutex_lock(m) pthread_mutex_lock(m) +#define ggml_mutex_unlock(m) pthread_mutex_unlock(m) +#define ggml_mutex_lock_shared(m) pthread_mutex_lock(m) +#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define ggml_lock_lock(x) _mm_pause() +#else +#define ggml_lock_lock(x) UNUSED(x) +#endif +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 +#define ggml_cond_init(c) pthread_cond_init(c, NULL) +#define ggml_cond_destroy(c) pthread_cond_destroy(c) +#define ggml_cond_wait(c, m) pthread_cond_wait(c, m) +#define ggml_cond_broadcast(c) pthread_cond_broadcast(c) + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +// Threadpool def +struct ggml_threadpool { + ggml_mutex_t mutex; // mutex for cond.var + ggml_cond_t cond; // cond.var for waiting for new work + + struct ggml_cgraph * cgraph; + struct ggml_cplan * cplan; // synchronization primitives + atomic_int n_graph; // incremented when there is work to be done (i.e each graph) atomic_int n_barrier; atomic_int n_barrier_passed; + atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. - ggml_abort_callback abort_callback; // abort ggml_graph_compute when true - void * abort_callback_data; + // these are atomic as an annotation for thread-sanitizer + atomic_bool stop; // Used for stopping the threadpool altogether + atomic_bool pause; // Used for pausing the threadpool or individual threads - atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads + struct ggml_compute_state * workers; // per thread state + int n_threads_max; // number of threads in the pool + int n_threads_cur; // number of threads used in the current graph + + int32_t prio; // Scheduling priority + uint32_t poll; // Polling level (0 - no polling) enum ggml_status ec; }; +// Per-thread state struct ggml_compute_state { +#ifndef GGML_USE_OPENMP ggml_thread_t thrd; + bool cpumask[GGML_MAX_N_THREADS]; + int last_graph; + bool pending; +#endif + struct ggml_threadpool * threadpool; int ith; - struct ggml_compute_state_shared * shared; }; struct ggml_compute_params { @@ -1900,7 +1998,7 @@ struct ggml_compute_params { size_t wsize; void * wdata; - struct ggml_compute_state_shared * shared; + struct ggml_threadpool * threadpool; }; // @@ -2310,7 +2408,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } @@ -2669,6 +2769,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, return sum; } +static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (ggml_float)expf(val); + } + return sum = (ggml_float)logf(sum); +} + inline static float ggml_silu_backward_f32(float x, float dy) { const float s = 1.0f/(1.0f + expf(-x)); return dy*s*(1.0f + x*(1.0f - s)); @@ -2760,6 +2873,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SQR", "SQRT", "LOG", + "SIN", + "COS", "SUM", "SUM_ROWS", "MEAN", @@ -2797,9 +2912,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", + "IM2COL_BACK", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", + "POOL_2D_BACK", "UPSCALE", "PAD", "ARANGE", @@ -2833,7 +2950,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2848,6 +2965,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "x^2", "√x", "log(x)", + "sin(x)", + "cos(x)", "Σx", "Σx_k", "Σx/n", @@ -2885,9 +3004,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", + "im2col_back(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", + "pool_2d_back(x)", "upscale(x)", "pad(x)", "arange(start, stop, step)", @@ -2921,7 +3042,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2948,6 +3069,19 @@ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); +// Helpers for polling loops +#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) +static inline void ggml_thread_cpu_relax(void) { + __asm__ volatile("yield" ::: "memory"); +} +#elif defined(__x86_64__) +static inline void ggml_thread_cpu_relax(void) { + _mm_pause(); +} +#else +static inline void ggml_thread_cpu_relax(void) {;} +#endif + // // NUMA support // @@ -2995,42 +3129,36 @@ inline static void ggml_critical_section_start(void) { } #ifdef GGML_USE_OPENMP -static void ggml_barrier(struct ggml_compute_state_shared * shared) { - if (shared->n_threads == 1) { +static void ggml_barrier(struct ggml_threadpool * threadpool) { + if (threadpool->n_threads_cur == 1) { return; } #pragma omp barrier } #else -static void ggml_barrier(struct ggml_compute_state_shared * shared) { - if (shared->n_threads == 1) { +static void ggml_barrier(struct ggml_threadpool * threadpool) { + if (threadpool->n_threads_cur == 1) { return; } - atomic_int * n_barrier = &shared->n_barrier; - atomic_int * n_barrier_passed = &shared->n_barrier_passed; + atomic_int * n_barrier = &threadpool->n_barrier; + atomic_int * n_barrier_passed = &threadpool->n_barrier_passed; - int n_threads = shared->n_threads; - int passed_old = atomic_load(n_barrier_passed); + int n_threads = threadpool->n_threads_cur; + int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed); if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) { // last thread atomic_store(n_barrier, 0); - atomic_fetch_add(n_barrier_passed, 1); + atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed); } else { // wait for other threads - const int n_spin_before_sleep = 100000; while (true) { - for (int i = 0; i < n_spin_before_sleep; i++) { - if (atomic_load(n_barrier_passed) != passed_old) { - return; - } - #if defined(__SSE3__) - _mm_pause(); - #endif + if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) { + return; } - sched_yield(); + ggml_thread_cpu_relax(); } } } @@ -3767,6 +3895,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( } struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + GGML_ASSERT(obj_new); // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here @@ -4486,8 +4615,6 @@ static struct ggml_tensor * ggml_add_impl( bool is_node = false; if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -4661,11 +4788,13 @@ static struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_can_repeat(b, a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -4880,6 +5009,72 @@ struct ggml_tensor * ggml_log_inplace( return ggml_log_impl(ctx, a, true); } +// ggml_sin + +static struct ggml_tensor * ggml_sin_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SIN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sin( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sin_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sin_impl(ctx, a, true); +} + +// ggml_cos + +static struct ggml_tensor * ggml_cos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_COS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cos( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cos_impl(ctx, a, true); +} + // ggml_sum struct ggml_tensor * ggml_sum( @@ -6727,17 +6922,20 @@ struct ggml_tensor * ggml_im2col( GGML_ASSERT(a->ne[2] == b->ne[2]); } else { GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(b->ne[3] == 1); } bool is_node = false; - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward + if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data is_node = true; } const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + const int64_t ne[4] = { is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, @@ -6757,6 +6955,37 @@ struct ggml_tensor * ggml_im2col( return result; } +struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + + bool is_node = false; + + if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OC, OH, OW] @@ -6770,7 +6999,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -6896,17 +7125,17 @@ struct ggml_tensor * ggml_pool_2d( bool is_node = false; if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } struct ggml_tensor * result; - const int64_t ne[3] = { + const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), a->ne[2], + a->ne[3], }; - result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); @@ -6917,6 +7146,37 @@ struct ggml_tensor * ggml_pool_2d( return result; } +struct ggml_tensor * ggml_pool_2d_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * af, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result; + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = af; + return result; +} + // ggml_upscale static struct ggml_tensor * ggml_upscale_impl( @@ -7095,7 +7355,8 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias) { + float max_bias, + float logit_softcap) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7122,7 +7383,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale, max_bias }; + float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -7142,7 +7403,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_attn_back @@ -7229,43 +7490,34 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq) { - GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_matrix(x)); + struct ggml_tensor * sx, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_3d(sx)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - GGML_ASSERT( s->ne[0] == d_conv - 1); - GGML_ASSERT( s->ne[1] == d_inner); - GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + // TODO: maybe support other strides than 1? + GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); bool is_node = false; - if (s->grad || x->grad || c->grad || sq->grad) { + if (sx->grad || c->grad) { GGML_ABORT("fatal error"); // TODO: implement is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = s; - result->src[1] = x; - result->src[2] = c; - result->src[3] = sq; + result->src[0] = sx; + result->src[1] = c; return result; } @@ -7279,39 +7531,42 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq) { + struct ggml_tensor * C) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tokens = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tokens); - GGML_ASSERT(C->ne[0] == d_state); - GGML_ASSERT(C->ne[1] == n_tokens); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); } bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { GGML_ABORT("fatal error"); // TODO: implement is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // concatenated y + ssm_states struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -7322,7 +7577,6 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = sq; return result; } @@ -9999,7 +10253,7 @@ static void ggml_compute_forward_acc_f32( ((char *) src0->data), ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); } const int ith = params->ith; @@ -10104,11 +10358,10 @@ static void ggml_compute_forward_sub_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } + assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int ith = params->ith; + const int nth = params->nth; const int nr = ggml_nrows(src0); @@ -10117,40 +10370,55 @@ static void ggml_compute_forward_sub_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // rows per thread + const int dr = (nr + nth - 1)/nth; + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE - vDSP_vsub( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); #else - ggml_vec_sub_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif - // } - // } + } } } else { // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; } @@ -10496,6 +10764,96 @@ static void ggml_compute_forward_log( } } +// ggml_compute_forward_sin + +static void ggml_compute_forward_sin_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sin_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sin( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sin_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_cos + +static void ggml_compute_forward_cos_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_cos_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_cos( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cos_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum static void ggml_compute_forward_sum_f32( @@ -10995,11 +11353,6 @@ static void ggml_compute_forward_concat_f32( GGML_TENSOR_BINARY_OP_LOCALS - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - const int32_t dim = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(dim >= 0 && dim < 4); @@ -12374,10 +12727,10 @@ UseGgmlGemm1:; if (ith == 0) { // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store(¶ms->shared->current_chunk, nth); + atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); #if GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { @@ -12485,7 +12838,7 @@ UseGgmlGemm2:; break; } - current_chunk = atomic_fetch_add(¶ms->shared->current_chunk, 1); + current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); } } @@ -12580,7 +12933,7 @@ static void ggml_compute_forward_mul_mat_id( } } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { @@ -12734,7 +13087,7 @@ static void ggml_compute_forward_out_prod_f32( if (ith == 0) { ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); // dst[:,:,:,:] = 0 // for i2,i3: @@ -12852,7 +13205,7 @@ static void ggml_compute_forward_out_prod_q_f32( if (ith == 0) { ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); // parallelize by last three dimensions @@ -13038,7 +13391,7 @@ static void ggml_compute_forward_set_f32( ((char *) src0->data), ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); } const int ith = params->ith; @@ -13617,7 +13970,7 @@ static void ggml_compute_forward_diag_mask_f32( ((char *) src0->data), ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); } // TODO: handle transposed/permuted matrices @@ -14393,7 +14746,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14481,7 +14834,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14536,6 +14889,7 @@ static void ggml_compute_forward_conv_transpose_1d( } } +// ggml_compute_forward_im2col_f32 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14546,7 +14900,6 @@ static void ggml_compute_forward_im2col_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -14577,7 +14930,6 @@ static void ggml_compute_forward_im2col_f32( int ofs0 = is_2D ? nb13 : nb12; int ofs1 = is_2D ? nb12 : nb11; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] @@ -14613,6 +14965,7 @@ static void ggml_compute_forward_im2col_f32( } +// ggml_compute_forward_im2col_f16 // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -14708,6 +15061,99 @@ static void ggml_compute_forward_im2col( } } +// ggml_compute_forward_im2col_back_f32 + +static void ggml_compute_forward_im2col_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne3 : ne2; + const int64_t IC = is_2D ? ne2 : ne1; + const int64_t IH = is_2D ? ne1 : 1; + const int64_t IW = ne0; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne12 : 1; + const int64_t OW = ne11; + + int ofs0 = is_2D ? nb3 : nb2; + int ofs1 = is_2D ? nb2 : nb1; + + GGML_ASSERT(nb0 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + for (int64_t iih = 0; iih < IH; iih++) { + for (int64_t iiw = 0; iiw < IW; iiw++) { + + // micro kernel + float grad = 0.0f; + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + // For s0 > 1 some values were skipped over in the forward pass. + // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. + const int64_t tmpw = (iiw + p0 - ikw*d0); + if (tmpw % s0 != 0) { + continue; + } + const int64_t iow = tmpw / s0; + + // Equivalent logic as above except for s1. + int64_t ioh; + if (is_2D) { + const int64_t tmph = iih + p1 - ikh*d1; + + if (tmph % s1 != 0) { + continue; + } + + ioh = tmph / s1; + } else { + ioh = 0; + } + + if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { + continue; + } + + const float * const src_data = (const float *) src1->data + + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + } + } + float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] + dst_data[iih*IW + iiw] = grad; + } + } + } + } + } +} // ggml_compute_forward_conv_transpose_2d @@ -14768,7 +15214,7 @@ static void ggml_compute_forward_conv_transpose_2d( memset(dst->data, 0, ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); const int32_t stride = ggml_get_op_params_i32(dst, 0); @@ -14950,6 +15396,128 @@ static void ggml_compute_forward_pool_2d( } } +// ggml_compute_forward_pool_2d_back + +static void ggml_compute_forward_pool_2d_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst + + assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + char * cdata = (char *) dst->data; + const char * cdataf = (const char *) dstf->data; + const char * const data_end = cdata + ggml_nbytes(dst); + + GGML_ASSERT(params->ith == 0); + memset(cdata, 0, ggml_nbytes(dst)); + + const int64_t px = src->ne[0]; + const int64_t py = src->ne[1]; + const int64_t pa = px * py; + + const float * splane = (const float *) src->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + const float * const srow = splane + oy * px; + for (int ox = 0; ox < px; ++ox) { + const float grad0 = srow[ox]; + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + if (op == GGML_OP_POOL_MAX) { + float maxval = -FLT_MAX; + int kxmax = -1; + int kymax = -1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + const float val = dst->type == GGML_TYPE_F32 ? + ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]); + if (val <= maxval) { + continue; + } + + maxval = val; + kxmax = kx; + kymax = ky; + } + } + + if (kxmax == -1 || kymax == -1) { + continue; + } + + void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); + const int j = ix + kxmax; + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad0; + } else { + ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j])); + } + } else if (op == GGML_OP_POOL_AVG) { + const float grad = grad0 / ka; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad; + } else { + ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad); + } + } + } + } else { + GGML_ASSERT(false); + } + } + } + + cdata += dst->nb[2]; + cdataf += dst->nb[2]; + splane += pa; + } +} + // ggml_compute_forward_upscale static void ggml_compute_forward_upscale_f32( @@ -15283,11 +15851,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -15351,7 +15925,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - s = s*scale + mv; // scale KQ value and apply mask + s = s*scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap*tanhf(s); + } + + s += mv; // apply mask const float Mold = M; @@ -15360,7 +15940,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - if (v->type== GGML_TYPE_F16) { + if (v->type == GGML_TYPE_F16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -15427,7 +16007,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[2]) { + switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -15502,7 +16082,7 @@ static void ggml_compute_forward_flash_attn_back_f32( if (ith == 0) { memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); const int64_t elem_q = ggml_nelements(q); const int64_t elem_k = ggml_nelements(k); @@ -15782,27 +16362,22 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_ssm_conv_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // conv_state - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight - const struct ggml_tensor * src3 = dst->src[3]; // state_seq + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv - const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const int nr = src0->ne[1]; // d_inner + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT( dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // for use with the destination state offset between sequences - GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -15812,76 +16387,29 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { - // multiple sequences means it's hard to know when it's the first time a state is read, - // so copy them all over to the destination, just to be sure. - for (int i3 = 0; i3 < n_kv; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); - // can't use memcpy because of d_conv vs d_conv - 1 + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + + // TODO: transpose the output for smaller strides for big batches? + // d_inner for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - // copy s0 to last (d_conv - 1) columns of s - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + + // d_conv + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; } + x[i1] = sumf; } } } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - int ne0s0; - - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} - ne0s0 = src0->ne[0]; - } else { - // the source is the last (d_conv - 1) columns of the destination - s0 = s + 1; - ne0s0 = nc; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // shift state left - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; - } - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; - } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } - - // it seems a little faster when this is separate from the state shift - for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - float sumf = 0.0f; - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; - } - x[i1] = sumf; - } - } } static void ggml_compute_forward_ssm_conv( @@ -15910,15 +16438,14 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - const struct ggml_tensor * src6 = dst->src[6]; // sq const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15927,12 +16454,12 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C, and when copying the states + // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[2]) - GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -15942,64 +16469,36 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { - // it's hard to know if the source states have already been copied - // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); - memcpy(s, s0, nc*ir*sizeof(float)); - } - } + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} - } else { - // otherwise the source is the same as the destination - s0 = s; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; - } - y[i1] = sumf; - } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } } } @@ -16274,7 +16773,7 @@ static void ggml_compute_forward_add_rel_pos_f32( if (params->ith == 0) { memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); } // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 @@ -16559,9 +17058,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( if (ith == 0) { memset(sums, 0, sizeof(float) * (nth + nth * nc)); } - ggml_barrier(params->shared); - - const double eps = 1e-9; + ggml_barrier(params->threadpool); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16583,20 +17080,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } #endif - // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max); - assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum >= 0.0); - // avoid log(0) by rescaling from [0..1] to [eps..1] - ggml_vec_scale_f32(nc, st, sum); - ggml_vec_add1_f32(nc, st, st, eps); - ggml_vec_log_f32(nc, st, st); + ggml_vec_add1_f32(nc, st, st, -sum); ggml_vec_mul_f32(nc, st, st, s1); - float st_sum = 0; + float st_sum = 0.0f; ggml_vec_sum_f32(nc, &st_sum, st); sums[ith] += st_sum; @@ -16607,7 +17099,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } #endif } - ggml_barrier(params->shared); + ggml_barrier(params->threadpool); if (ith == 0) { float * dp = (float *) dst->data; @@ -16653,8 +17145,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - const double eps = 1e-9; - // TODO: handle transposed/permuted matrices const int64_t nc = src0->ne[0]; const int64_t nr = ggml_nrows(src0); @@ -16686,11 +17176,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( ggml_vec_max_f32(nc, &max, s0); ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); - sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, ds0, 1.0/sum); // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - ggml_vec_scale_f32(nc, ds0, sum); - ggml_vec_add1_f32(nc, ds0, ds0, eps); ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); @@ -16771,6 +17259,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_log(params, tensor); } break; + case GGML_OP_SIN: + { + ggml_compute_forward_sin(params, tensor); + } break; + case GGML_OP_COS: + { + ggml_compute_forward_cos(params, tensor); + } break; case GGML_OP_SUM: { ggml_compute_forward_sum(params, tensor); @@ -16911,6 +17407,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col(params, tensor); } break; + case GGML_OP_IM2COL_BACK: + { + ggml_compute_forward_im2col_back_f32(params, tensor); + } break; case GGML_OP_CONV_TRANSPOSE_2D: { ggml_compute_forward_conv_transpose_2d(params, tensor); @@ -16923,6 +17423,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_pool_2d(params, tensor); } break; + case GGML_OP_POOL_2D_BACK: + { + ggml_compute_forward_pool_2d_back(params, tensor); + } break; case GGML_OP_UPSCALE: { ggml_compute_forward_upscale(params, tensor); @@ -17291,7 +17795,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + if (ggml_are_same_shape(src0, src1)) { + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + } else { + src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table); + } } } break; case GGML_OP_ADD1: @@ -17417,6 +17925,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_SIN: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, + tensor->grad, + ggml_cos(ctx, src0)), + zero_table); + } + } break; + case GGML_OP_COS: + { + if (src0->grad) { + src0->grad = + ggml_sub_or_set(ctx, + src0->grad, + ggml_mul(ctx, + tensor->grad, + ggml_sin(ctx, src0)), + zero_table); + } + } break; case GGML_OP_SUM: { if (src0->grad) { @@ -17864,6 +18396,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ABORT("fatal error"); // TODO: not implemented } case GGML_OP_IM2COL: + { + if (src1->grad) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + + src1->grad = ggml_add_or_set(ctx, + src1->grad, + ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), + zero_table); + } + } break; + case GGML_OP_IM2COL_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented } @@ -17876,6 +18425,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ABORT("fatal error"); // TODO: not implemented } case GGML_OP_POOL_2D: + { + if (src0->grad) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), + zero_table); + } + } break; + case GGML_OP_POOL_2D_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented } @@ -18165,6 +18731,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { GGML_ASSERT(gf->n_nodes > 0); + GGML_ASSERT(gf->grads); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph if (keep) { @@ -18348,65 +18915,6 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) { ggml_hash_set_reset(&cgraph->visited_hash_set); } -// -// thread data -// -// synchronization is done via busy loops -// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops -// - -#ifdef __APPLE__ - -//#include -// -//typedef os_unfair_lock ggml_lock_t; -// -//#define ggml_lock_init(x) UNUSED(x) -//#define ggml_lock_destroy(x) UNUSED(x) -//#define ggml_lock_lock os_unfair_lock_lock -//#define ggml_lock_unlock os_unfair_lock_unlock -// -//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#define ggml_lock_lock(x) UNUSED(x) -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#else - -//typedef pthread_spinlock_t ggml_lock_t; - -//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) -//#define ggml_lock_destroy pthread_spin_destroy -//#define ggml_lock_lock pthread_spin_lock -//#define ggml_lock_unlock pthread_spin_unlock - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) -#define ggml_lock_lock(x) _mm_pause() -#else -#define ggml_lock_lock(x) UNUSED(x) -#endif -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#endif - // Android's libc implementation "bionic" does not support setting affinity #if defined(__gnu_linux__) static void set_numa_thread_affinity(int thread_n) { @@ -18504,6 +19012,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -18590,6 +19100,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); } break; case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -18597,6 +19108,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_POOL_1D: case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: { n_tasks = 1; } break; @@ -18683,9 +19195,268 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { return n_tasks; } -struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) { +static thread_ret_t ggml_graph_compute_secondary_thread(void* data); + +#if defined(_WIN32) +#include "windows.h" + +// TODO: support > 64 CPUs +bool ggml_thread_apply_affinity(bool * mask) { + HANDLE h = GetCurrentThread(); + uint64_t bitmask = 0ULL; + + assert(GGML_MAX_N_THREADS >= 64); + + for (int32_t i = 0; i < 8; i++) { + int32_t idx = i * 8; + uint8_t val = 0; + val |= mask[idx + 0] << 0; + val |= mask[idx + 1] << 1; + val |= mask[idx + 2] << 2; + val |= mask[idx + 3] << 3; + val |= mask[idx + 4] << 4; + val |= mask[idx + 5] << 5; + val |= mask[idx + 6] << 6; + val |= mask[idx + 7] << 7; + bitmask |= (uint64_t)val << idx; + } + + for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); + break; + } + } + + DWORD_PTR m = (DWORD_PTR)bitmask; + + m = SetThreadAffinityMask(h, m); + + return m != 0; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. + // This is up to the applications. + DWORD p = THREAD_PRIORITY_NORMAL; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; + case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; + case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; + case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + if (!SetThreadPriority(GetCurrentThread(), p)) { + fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#elif defined(__APPLE__) +#include +#include + +static bool ggml_thread_apply_affinity(const bool * mask) { + // Not supported on Apple platforms + UNUSED(mask); + return true; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#else // posix? + +static bool ggml_thread_apply_affinity(const bool * mask) { + cpu_set_t cpuset; + int err; + + CPU_ZERO(&cpuset); + + for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); + CPU_SET(i, &cpuset); + } + } + +#ifdef __ANDROID__ + err = sched_setaffinity(0, sizeof(cpuset), &cpuset); + if (err < 0) { + err = errno; + } +#else + err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); +#endif + if (err != 0) { + fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); + return false; + } + + return true; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#endif + +static bool ggml_thread_cpumask_is_valid(const bool * mask) { + for (int i = 0; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { return true; } + } + return false; +} + +static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { + if (!strict) { + memcpy(local_mask, global_mask, GGML_MAX_N_THREADS); + return; + } else { + memset(local_mask, 0, GGML_MAX_N_THREADS); + int32_t base_idx = *iter; + for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + int32_t idx = base_idx + i; + if (idx >= GGML_MAX_N_THREADS) { + // Just a cheaper modulo + idx -= GGML_MAX_N_THREADS; + } + if (global_mask[idx]) { + local_mask[idx] = 1; + *iter = idx + 1; + return; + } + } + } +} + +void ggml_threadpool_free(struct ggml_threadpool* threadpool) { + if (!threadpool) return; + +#ifndef GGML_USE_OPENMP + struct ggml_compute_state* workers = threadpool->workers; + const int n_threads = threadpool->n_threads_max; + + ggml_mutex_lock(&threadpool->mutex); + + threadpool->stop = true; + threadpool->pause = false; + + ggml_cond_broadcast(&threadpool->cond); + ggml_mutex_unlock(&threadpool->mutex); + + for (int j = 1; j < n_threads; j++) { + int32_t rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED); + UNUSED(rc); + } + + ggml_mutex_destroy(&threadpool->mutex); + ggml_cond_destroy(&threadpool->cond); +#endif // GGML_USE_OPENMP + + GGML_ALIGNED_FREE(threadpool->workers); + GGML_ALIGNED_FREE(threadpool); +} + +#ifndef GGML_USE_OPENMP +// pause/resume must be called under mutex +static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) { + GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + ggml_cond_broadcast(&threadpool->cond); +} + +static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) { + GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + ggml_cond_broadcast(&threadpool->cond); +} +#endif + +void ggml_threadpool_pause(struct ggml_threadpool * threadpool) { +#ifndef GGML_USE_OPENMP + ggml_mutex_lock(&threadpool->mutex); + if (!threadpool->pause) { + ggml_threadpool_pause_locked(threadpool); + } + ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +void ggml_threadpool_resume(struct ggml_threadpool * threadpool) { +#ifndef GGML_USE_OPENMP + ggml_mutex_lock(&threadpool->mutex); + if (threadpool->pause) { + ggml_threadpool_resume_locked(threadpool); + } + ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +struct ggml_cplan ggml_graph_plan( + const struct ggml_cgraph * cgraph, + int n_threads, + struct ggml_threadpool * threadpool) { + + if (threadpool == NULL) { + GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + } if (n_threads <= 0) { - n_threads = GGML_DEFAULT_N_THREADS; + n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } size_t work_size = 0; @@ -18841,12 +19612,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } if (work_size > 0) { - work_size += CACHE_LINE_SIZE*(n_threads - 1); + work_size += CACHE_LINE_SIZE*(n_threads); } - cplan.n_threads = MIN(max_tasks, n_threads); - cplan.work_size = work_size; - cplan.work_data = NULL; + cplan.threadpool = threadpool; + cplan.n_threads = MIN(max_tasks, n_threads); + cplan.work_size = work_size; + cplan.work_data = NULL; return cplan; } @@ -18854,17 +19626,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; - const struct ggml_cgraph * cgraph = state->shared->cgraph; - const struct ggml_cplan * cplan = state->shared->cplan; + const struct ggml_cgraph * cgraph = state->threadpool->cgraph; + const struct ggml_cplan * cplan = state->threadpool->cplan; set_numa_thread_affinity(state->ith); struct ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ state->shared->n_threads, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.shared=*/ state->shared, + /*.ith =*/ state->ith, + /*.nth =*/ state->threadpool->n_threads_cur, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool=*/ state->threadpool, }; for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { @@ -18873,12 +19645,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->ec = GGML_STATUS_ABORTED; + state->threadpool->ec = GGML_STATUS_ABORTED; } - ggml_barrier(state->shared); + ggml_barrier(state->threadpool); - if (state->shared->ec != GGML_STATUS_SUCCESS) { + if (state->threadpool->ec != GGML_STATUS_SUCCESS) { break; } } @@ -18886,24 +19658,243 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { return 0; } +#ifndef GGML_USE_OPENMP + +static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + if (state->pending || threadpool->stop || threadpool->pause) { return true; } + + // check for new graph/work + int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + if (new_graph != state->last_graph) { + state->pending = (state->ith < threadpool->n_threads_cur); + state->last_graph = new_graph; + } + + return state->pending; +} + +static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + // This seems to make 0 ... 100 a decent range for polling level across modern processors. + // Perhaps, we can adjust it dynamically based on load and things. + const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; + + for (uint64_t i=0; !ggml_graph_compute_ready(state) && ipending; +} + +static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + if (ggml_graph_compute_poll_for_work(state)) { + return state->pending; + } + + ggml_mutex_lock_shared(&threadpool->mutex); + while (!ggml_graph_compute_ready(state)) { + // No new work. Wait for the signal. + GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith); + ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + ggml_mutex_unlock_shared(&threadpool->mutex); + + return state->pending; +} + +static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + struct ggml_threadpool * threadpool = state->threadpool; + + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(state->cpumask)) { + ggml_thread_apply_affinity(state->cpumask); + } + + while (true) { + // Check if we need to sleep + while (threadpool->pause) { + GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); + ggml_mutex_lock_shared(&threadpool->mutex); + if (threadpool->pause) { + ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); + ggml_mutex_unlock_shared(&threadpool->mutex); + } + + // This needs to be checked for after the cond_wait + if (threadpool->stop) break; + + // Check if there is new work + // The main thread is the only one that can dispatch new work + + ggml_graph_compute_check_for_work(state); + if (state->pending) { + state->pending = false; + + ggml_graph_compute_thread(state); + } + } + + return (thread_ret_t) 0; +} + +// Start processing new graph +static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool) +{ + // always take the mutex here because the worker threads are doing hybrid poll/wait + + ggml_mutex_lock(&threadpool->mutex); + + atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed); + + if (threadpool->pause) { + // Update main thread prio and affinity to match the threadpool settings + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + + // resume does cond broadcast + ggml_threadpool_resume_locked(threadpool); + } else { + ggml_cond_broadcast(&threadpool->cond); + } + + ggml_mutex_unlock(&threadpool->mutex); +} + +#endif // GGML_USE_OPENMP + +void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) +} + +struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { + struct ggml_threadpool_params p; + ggml_threadpool_params_init(&p, n_threads); + return p; +} + +bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; +} + +static struct ggml_threadpool * ggml_threadpool_new_impl( + struct ggml_threadpool_params * tpp, + struct ggml_cgraph * cgraph, + struct ggml_cplan * cplan) { + + struct ggml_threadpool * threadpool = + GGML_ALIGNED_MALLOC(sizeof(struct ggml_threadpool)); + { + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->n_graph = 0; + threadpool->n_barrier = 0; + threadpool->n_barrier_passed = 0; + threadpool->current_chunk = 0; + threadpool->stop = false; + threadpool->pause = tpp->paused; + threadpool->workers = NULL; + threadpool->n_threads_max = tpp->n_threads; + threadpool->n_threads_cur = tpp->n_threads; + threadpool->poll = tpp->poll; + threadpool->prio = tpp->prio; + threadpool->ec = GGML_STATUS_SUCCESS; + } + + // Allocate and init workers state + const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads; + struct ggml_compute_state * workers = GGML_ALIGNED_MALLOC(workers_size); + + memset(workers, 0, workers_size); + for (int j = 0; j < tpp->n_threads; j++) { + workers[j].threadpool = threadpool; + workers[j].ith = j; + } + + threadpool->workers = workers; + +#ifndef GGML_USE_OPENMP + ggml_mutex_init(&threadpool->mutex); + ggml_cond_init(&threadpool->cond); + + // Spin the threads for all workers, and update CPU placements. + // Place the main thread last (towards the higher numbered CPU cores). + + int32_t cpumask_iter = 0; + + for (int j = 1; j < tpp->n_threads; j++) { + ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); + + int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]); + GGML_ASSERT(rc == 0); + } + + ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); + + if (!threadpool->pause) { + // Update main thread prio and affinity at the start, otherwise we'll do it in resume + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + } +#endif // GGML_USE_OPENMP + + return threadpool; +} + +struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) { + return ggml_threadpool_new_impl(tpp, NULL, NULL); +} + enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { GGML_ASSERT(cplan); GGML_ASSERT(cplan->n_threads > 0); GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); - int n_threads = cplan->n_threads; + int n_threads = cplan->n_threads; + struct ggml_threadpool * threadpool = cplan->threadpool; - struct ggml_compute_state_shared state_shared = { - /*.cgraph =*/ cgraph, - /*.cgraph_plan =*/ cplan, - /*.n_threads =*/ n_threads, - /*.n_barrier =*/ 0, - /*.n_barrier_passed =*/ 0, - /*.abort_callback =*/ NULL, - /*.abort_callback_data =*/ NULL, - /*.current_chunk =*/ 0, - /*.ec =*/ GGML_STATUS_SUCCESS, - }; + bool disposable_threadpool = false; + + if (threadpool == NULL) { + GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + disposable_threadpool = true; + + struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads); + threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan); + } else { + // Reset some of the parameters that need resetting + // No worker threads should be accessing the parameters below at this stage + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->n_threads_cur = n_threads; + threadpool->current_chunk = 0; + threadpool->ec = GGML_STATUS_SUCCESS; + } + + if (n_threads > threadpool->n_threads_max) { + GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n"); + } #ifdef GGML_USE_OPENMP if (n_threads > 1) { @@ -18913,63 +19904,36 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl { // update the number of threads from the actual number of threads that we got from OpenMP n_threads = omp_get_num_threads(); - state_shared.n_threads = n_threads; + threadpool->n_threads_cur = n_threads; } - struct ggml_compute_state worker = { - .thrd = 0, - .ith = omp_get_thread_num(), - .shared = &state_shared, - }; - ggml_graph_compute_thread(&worker); + ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); } } else { - struct ggml_compute_state worker = { - .thrd = 0, - .ith = 0, - .shared = &state_shared, - }; - ggml_graph_compute_thread(&worker); + ggml_graph_compute_thread(&threadpool->workers[0]); } #else - struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); + // Kick all threads to start the new graph + ggml_graph_compute_kickoff(threadpool); - for (int j = 0; j < n_threads; ++j) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - }; - } - - // create thread pool - for (int j = 1; j < n_threads; ++j) { - const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - - // this is a work thread too - ggml_graph_compute_thread(&workers[0]); - - // join or kill thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; j++) { - const int rc = ggml_thread_join(workers[j].thrd, NULL); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } + // This is a work thread too + ggml_graph_compute_thread(&threadpool->workers[0]); #endif // don't leave affinity set on the main thread clear_numa_thread_affinity(); - return state_shared.ec; + enum ggml_status ret = threadpool->ec; + + if (disposable_threadpool) { + ggml_threadpool_free(threadpool); + } + + return ret; } enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { - struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); + struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); @@ -19110,9 +20074,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19142,9 +20108,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { const uint32_t type = tensor->type; const uint32_t op = tensor->op; + const int32_t flags = tensor->flags; fwrite(&type, sizeof(uint32_t), 1, fout); fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&flags, sizeof(int32_t), 1, fout); for (int j = 0; j < GGML_MAX_DIMS; ++j) { const uint64_t ne = tensor->ne[j]; @@ -19203,6 +20171,14 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { } } } + + // dump the data + // TODO: pad this to 32 byte boundary + if ((flags & GGML_TENSOR_FLAG_PARAM)) { + const size_t size = ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } } } @@ -19316,10 +20292,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_leafs; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; @@ -19337,20 +20315,19 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); - tensor->op = (enum ggml_op) op; + tensor->op = (enum ggml_op) op; + tensor->flags = flags; memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; - tensor->data = (void *) ptr; - for (int j = 0; j < GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j]; } - result->leafs[i] = tensor; + tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); - ptr += ggml_nbytes(tensor); + result->leafs[i] = tensor; fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } @@ -19362,10 +20339,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * { uint32_t type; uint32_t op; + int32_t flags; for (uint32_t i = 0; i < n_nodes; ++i) { type = *(const uint32_t *) ptr; ptr += sizeof(type); op = *(const uint32_t *) ptr; ptr += sizeof(op); + flags = *(const int32_t *) ptr; ptr += sizeof(flags); enum ggml_op eop = (enum ggml_op) op; @@ -19455,6 +20434,11 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * result->nodes[i] = tensor; + // TODO tensor data is be duplicated due to ggml_new_tensor call above + if (flags & GGML_TENSOR_FLAG_PARAM) { + tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); + } + fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); } } @@ -19723,6 +20707,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_opt_callback callback, void * callback_data) { GGML_ASSERT(ggml_is_scalar(f)); + GGML_ASSERT(f->type == GGML_TYPE_F32); // these will store the parameters we want to optimize struct ggml_tensor * ps[GGML_MAX_PARAMS]; @@ -19764,7 +20749,7 @@ static enum ggml_opt_result ggml_opt_adam( float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads, NULL); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; @@ -20111,7 +21096,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( opt->iter = iter; } - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads, NULL); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; @@ -20489,6 +21474,8 @@ enum ggml_opt_result ggml_opt( struct ggml_context * ctx, struct ggml_opt_params params, struct ggml_tensor * f) { + GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor."); + bool free_ctx = false; if (ctx == NULL) { struct ggml_init_params params_ctx = { @@ -20543,6 +21530,8 @@ enum ggml_opt_result ggml_opt_resume_g( ggml_opt_callback callback, void * callback_data) { + GGML_ASSERT(f->grad && "ggml_set_param must be called for at least one ancestor"); + // build forward + backward compute graphs enum ggml_opt_result result = GGML_OPT_RESULT_OK; @@ -21630,6 +22619,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { void gguf_add_tensor( struct gguf_context * ctx, const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); if (gguf_find_tensor(ctx, tensor->name) != -1) { GGML_ABORT("duplicated tensor name"); } diff --git a/ggml/src/vulkan-shaders/acc.comp b/ggml/src/vulkan-shaders/acc.comp new file mode 100644 index 000000000..4c8739efe --- /dev/null +++ b/ggml/src/vulkan-shaders/acc.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)])); + } +} + diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp new file mode 100644 index 000000000..f9a858cbf --- /dev/null +++ b/ggml/src/vulkan-shaders/cos.comp @@ -0,0 +1,15 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp new file mode 100644 index 000000000..7faf9be93 --- /dev/null +++ b/ggml/src/vulkan-shaders/sin.comp @@ -0,0 +1,15 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index 53ceb13d3..0c5b7b279 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -368,6 +368,10 @@ void process_shaders(std::vector>& tasks) { string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); })); @@ -392,6 +396,14 @@ void process_shaders(std::vector>& tasks) { string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f63ec450a..b55effa99 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class Keys: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: MODEL = "tokenizer.ggml.model" @@ -219,6 +220,8 @@ class MODEL_ARCH(IntEnum): T5 = auto() T5ENCODER = auto() JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() class MODEL_TENSOR(IntEnum): @@ -347,6 +350,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1065,6 +1070,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.EXAONE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -1105,6 +1141,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.NEMOTRON: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # @@ -1333,6 +1373,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 76385a828..af3b98c67 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ class GGUFWriter: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: + self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 9aa2209e2..a4f185c06 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -10,10 +10,10 @@ class TensorNameMap: # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais + "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf + "model.embed_tokens", # llama-hf nemotron "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -52,7 +52,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -62,7 +62,7 @@ class TensorNameMap: # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox - "transformer.ln_f", # gpt2 gpt-j falcon jais + "transformer.ln_f", # gpt2 gpt-j falcon jais exaone "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth "transformer.norm_f", # mpt dbrx @@ -75,6 +75,7 @@ class TensorNameMap: "transformer.rms_norm", # Grok "encoder.final_layernorm", # chatglm "transformer.norm", # openelm + "model.norm", # nemotron ), # Rope frequencies @@ -88,12 +89,12 @@ class TensorNameMap: # Attention norm MODEL_TENSOR.ATTN_NORM: ( "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone "transformer.blocks.{bid}.norm_1", # mpt "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf + "model.layers.{bid}.input_layernorm", # llama-hf nemotron "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -135,18 +136,19 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok + "transformer.h.{bid}.attn.attention.q_proj", # exaone ), # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j @@ -154,18 +156,20 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok + "transformer.h.{bid}.attn.attention.k_proj", # exaone ), # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok + "transformer.h.{bid}.attn.attention.v_proj", # exaone ), # Attention output @@ -175,7 +179,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j @@ -190,6 +194,7 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm + "transformer.h.{bid}.attn.attention.out_proj", # exaone ), # Attention output norm @@ -215,10 +220,10 @@ class TensorNameMap: # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais + "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -258,7 +263,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "h.{bid}.mlp.dense_h_to_4h", # bloom - "model.layers.{bid}.mlp.up_proj", # llama-hf refact + "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j @@ -277,6 +282,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm + "transformer.h.{bid}.mlp.c_fc_1", # exaone ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -308,6 +314,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic + "transformer.h.{bid}.mlp.c_fc_0", # exaone ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -329,7 +336,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "h.{bid}.mlp.dense_4h_to_h", # bloom - "model.layers.{bid}.mlp.down_proj", # llama-hf + "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j @@ -347,6 +354,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm + "model.layers.h.{bid}.mlp.c_proj", # exaone ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 19f6761e2..eea381e5a 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.1" +version = "0.10.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/include/llama.h b/include/llama.h index 3c28cf0b5..f2e701602 100644 --- a/include/llama.h +++ b/include/llama.h @@ -93,6 +93,9 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, }; enum llama_rope_type { @@ -264,9 +267,9 @@ extern "C" { enum llama_split_mode split_mode; // how to split the model across multiple GPUs // main_gpu interpretation depends on split_mode: - // LLAMA_SPLIT_NONE: the GPU that is used for the entire model - // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results - // LLAMA_SPLIT_LAYER: ignored + // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model + // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results + // LLAMA_SPLIT_MODE_LAYER: ignored int32_t main_gpu; // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() @@ -301,8 +304,8 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) - uint32_t n_threads; // number of threads to use for generation - uint32_t n_threads_batch; // number of threads to use for batch processing + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id @@ -425,6 +428,13 @@ extern "C" { //optional: LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); + // Optional: an auto threadpool gets created in ggml if not passed explicitly + LLAMA_API void llama_attach_threadpool( + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); + // Call once at the end of the program - currently only used for MPI LLAMA_API void llama_backend_free(void); @@ -508,6 +518,9 @@ extern "C" { // to the decoder to start generating output sequence. For other models, it returns -1. LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); + // Returns true if the model is recurrent (like Mamba, RWKV, etc.) + LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -831,13 +844,13 @@ extern "C" { // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) - LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch); // Get the number of threads used for generation of a single token. - LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx); + LLAMA_API int32_t llama_n_threads(struct llama_context * ctx); // Get the number of threads used for prompt and batch processing (multiple token). - LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx); + LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx); // Set whether the model is in embeddings mode or not // If true, embeddings will be returned but logits will not @@ -912,11 +925,8 @@ extern "C" { LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding - // Returns -1 if unknown, 1 for true or 0 for false. - LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model); - - // Returns -1 if unknown, 1 for true or 0 for false. - LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); + LLAMA_API bool llama_add_bos_token(const struct llama_model * model); + LLAMA_API bool llama_add_eos_token(const struct llama_model * model); // Codellama infill tokens LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index eef6768b1..1e6db754f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -797faa25af14126eb30134d4033139ae3c5428ed +28b7633d733bbeef0026570fbc61c79c5e9aa5ae diff --git a/src/llama-impl.h b/src/llama-impl.h index 399b134a7..952774096 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -31,11 +31,17 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 749f85718..323660ef5 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -321,6 +321,21 @@ private: // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused +template, typename Compare = std::less> +class llama_priority_queue : public std::priority_queue { +public: + using std::priority_queue::priority_queue; + + T pop_move() { + T item = std::move(this->c.front()); + std::pop_heap(this->c.begin(), this->c.end(), this->comp); + this->c.pop_back(); + return item; + } + + void pop() = delete; +}; + struct llm_bigram_bpe { struct comparator { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { @@ -329,7 +344,7 @@ struct llm_bigram_bpe { }; using queue_storage = std::vector; - using queue = std::priority_queue; + using queue = llama_priority_queue; llm_symbol::index left; llm_symbol::index right; std::string text; @@ -388,6 +403,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + case LLAMA_VOCAB_PRE_TYPE_EXAONE: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -410,6 +426,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_PORO: + case LLAMA_VOCAB_PRE_TYPE_BLOOM: + case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: regex_exprs = { " ?[^(\\s|.,!?…。,、।۔،)]+", }; @@ -517,8 +535,7 @@ struct llm_tokenizer_bpe { // build token(s) while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); + auto bigram = work_queue.pop_move(); auto & left_symbol = symbols[bigram.left]; auto & right_symbol = symbols[bigram.right]; @@ -1466,11 +1483,11 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) { return vocab.special_pad_id; } -int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) { +bool llama_add_bos_token_impl(const struct llama_vocab & vocab) { return vocab.tokenizer_add_bos; } -int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) { +bool llama_add_eos_token_impl(const struct llama_vocab & vocab) { return vocab.tokenizer_add_eos; } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 7adfc16da..6e8f30be4 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -95,8 +95,8 @@ llama_token llama_token_sep_impl(const struct llama_vocab & vocab); llama_token llama_token_nl_impl (const struct llama_vocab & vocab); llama_token llama_token_pad_impl(const struct llama_vocab & vocab); -int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab); -int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab); +bool llama_add_bos_token_impl(const struct llama_vocab & vocab); +bool llama_add_eos_token_impl(const struct llama_vocab & vocab); llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); llama_token llama_token_middle_impl(const struct llama_vocab & vocab); diff --git a/src/llama.cpp b/src/llama.cpp index 7f2f00031..2274296b4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -210,6 +210,8 @@ enum llm_arch { LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_EXAONE, LLM_ARCH_UNKNOWN, }; @@ -255,6 +257,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -324,6 +328,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_PRE, @@ -422,6 +427,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -1296,6 +1302,43 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_NEMOTRON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_EXAONE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2196,6 +2239,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -2245,6 +2289,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2328,8 +2373,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - uint32_t n_threads; // number of threads to use for generation - uint32_t n_threads_batch; // number of threads to use for batch processing + int n_threads; // number of threads to use for generation + int n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -2471,10 +2516,29 @@ struct llama_layer { struct ggml_tensor * ffn_down_scale; }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + // TODO: whole_seqs for embeddings? + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states + int32_t src = -1; // used by recurrent state models to copy states + int32_t tail = -1; std::set seq_id; @@ -2495,7 +2559,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token bool v_trans = true; // the value tensor is transposed @@ -2658,6 +2721,340 @@ struct llama_model { } }; +struct llama_sbatch_seq { + int32_t n_seq_id; + llama_seq_id * seq_id; + size_t offset; + size_t length; + + // helper for smoother batch API transition -- can be deprecated in the future + llama_seq_id all_seq_id; // used if seq_id == NULL +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + /*equal_seqs =*/ true, + /*n_tokens =*/ 0, + /*n_seq_tokens =*/ 0, + /*n_seqs =*/ 0, + /*token =*/ !has_embd ? ubatch_token.data() : nullptr, + /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, + /*pos =*/ ubatch_pos.data(), + /*n_seq_id =*/ ubatch_n_seq_id.data(), + /*seq_id =*/ ubatch_seq_id.data(), + /*output =*/ ubatch_output.data(), + }; + return ubatch; + } + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + (n_embd * seq.offset); + } + } else { + ubatch.embd = nullptr; + } + // from here on, the else branches are deprecated; + // they are helpers for smoother batch API transition + if (batch->pos) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + llama_pos bi = ids[seq.offset + i]; + ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + } + } + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } else { + GGML_ASSERT(seq.n_seq_id == 1); + ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; + } + } else { + // simple split + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; + } + } + } + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; + } + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); + } + + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; + } + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (simple_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + s.all_seq_id = batch.all_seq_id; + return; + } + std::sort(ids.begin(), ids.end(), + [&batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id (assuming batch.all_pos_1 is positive) + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; + seq.push_back(new_seq); + last_seq = &seq.back(); + } + } else { + llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + seq.push_back(new_seq); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); + } +}; + struct llama_context { llama_context(const llama_model & model) : model(model) @@ -2679,6 +3076,7 @@ struct llama_context { struct llama_cparams cparams; struct llama_sampling sampling; + struct llama_sbatch sbatch; struct llama_kv_cache kv_self; struct llama_control_vector cvec; @@ -2693,6 +3091,9 @@ struct llama_context { #endif ggml_backend_t backend_cpu = nullptr; + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + bool has_evaluated_once = false; int64_t t_start_us; @@ -2939,8 +3340,7 @@ static bool llama_kv_cache_init( cache.has_shift = false; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.recurrent = llama_model_is_recurrent(&model); cache.v_trans = !cache.recurrent && !cparams.flash_attn; cache.head = 0; @@ -2953,13 +3353,6 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } - // count used buffer types std::map buft_layer_count; if (offload) { @@ -3027,45 +3420,161 @@ static bool llama_kv_cache_init( // to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_batch & batch) { + const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_seqs = batch.n_seqs; + const uint32_t n_seq_tokens = batch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. - llama_seq_id min = cache.size - 1; - llama_seq_id max = 0; + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(batch.equal_seqs); - for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { - if (seq_id > max) { - max = seq_id; - } - if (seq_id < min) { - min = seq_id; - } - // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); - } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; - } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set - } else { + int32_t min = cache.size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = batch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { // too big seq_id - // TODO: would it be possible to resize the KV cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); return false; } + if (j > 0) { + llama_kv_cell & seq = cache.cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + cache.used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(cache.size, -1); + for (uint32_t i = 0; i < cache.size; ++i) { + llama_kv_cell & cell = cache.cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < cache.size; ++i) { + if (tails_verif[i] != cache.cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = cache.head; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + llama_kv_cell & seq_meta = cache.cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cache.cells[dst_id]; + llama_kv_cell & src_cell = cache.cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cache.cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cache.cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cache.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cache.cells[seq_id].tail = cell_id; } } @@ -3074,7 +3583,7 @@ static bool llama_kv_cache_find_slot( cache.n = max - min + 1; // sanity check - return max >= min; + return cache.n >= n_seqs; } // otherwise, one cell per token. @@ -3112,11 +3621,14 @@ static bool llama_kv_cache_find_slot( } } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.cells[cache.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } @@ -3142,6 +3654,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { for (int32_t i = 0; i < (int32_t) cache.size; ++i) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + cache.cells[i].src = -1; + cache.cells[i].tail = -1; } cache.head = 0; cache.used = 0; @@ -3168,9 +3682,16 @@ static bool llama_kv_cache_seq_rm( return false; } if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; + int32_t & tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cache.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + if (p0 <= cell.pos && p1 < cell.pos) { + tail_id = -1; + } } } else { // seq_id is negative, then the range should include everything or nothing @@ -3194,6 +3715,7 @@ static bool llama_kv_cache_seq_rm( if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + cache.cells[i].src = -1; if (new_head == cache.size) new_head = i; } } @@ -3216,23 +3738,29 @@ static void llama_kv_cache_seq_cp( if (cache.recurrent) { if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; + llama_kv_cell & tail_src = cache.cells[seq_id_src]; + llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; - // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); - } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.delta = -1; + cell_dst.src = -1; + cache.used -= 1; + } } + if (tail_src.tail >= 0) { + llama_kv_cell & cell_src = cache.cells[tail_src.tail]; - cache.do_copy = true; - - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } } + return; } // otherwise, this is the KV cache of a Transformer-like model @@ -3250,9 +3778,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id uint32_t new_head = cache.size; for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.recurrent && (llama_seq_id) i != seq_id) { + cache.cells[i].tail = -1; + } if (!cache.cells[i].has_seq_id(seq_id)) { if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + cache.cells[i].src = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; } else { @@ -3281,9 +3813,12 @@ static void llama_kv_cache_seq_add( if (cache.recurrent) { // for Mamba-like models, only the pos needs to be shifted if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } } } return; @@ -3327,9 +3862,12 @@ static void llama_kv_cache_seq_div( if (cache.recurrent) { // for Mamba-like models, only the pos needs to be changed if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } } } return; @@ -3361,7 +3899,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama } static void llama_kv_cache_defrag(struct llama_kv_cache & cache) { - cache.do_defrag = true; + if (!cache.recurrent) { + cache.do_defrag = true; + } } static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { @@ -5011,6 +5551,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5235,6 +5776,23 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_EXAONE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5467,6 +6025,15 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } else if ( + tokenizer_pre == "exaone") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -5840,6 +6407,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); @@ -6040,6 +6608,7 @@ static bool llm_load_tensors( const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -6097,7 +6666,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); if (n_expert == 0) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); @@ -6105,9 +6674,9 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional MLP bias - layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); @@ -6431,7 +7000,7 @@ static bool llm_load_tensors( layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens - layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); @@ -7550,8 +8119,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); @@ -7562,6 +8131,78 @@ static bool llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; + case LLM_ARCH_NEMOTRON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_EXAONE: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7803,7 +8444,7 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_batch & batch, + const llama_ubatch & batch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; @@ -8237,9 +8878,10 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } @@ -8248,7 +8890,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -8352,12 +8994,180 @@ static struct ggml_tensor * llm_build_kv( return cur; } +static struct ggml_tensor * llm_build_copy_mask_state( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct ggml_tensor * s, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t n_state, + int32_t kv_size, + int32_t kv_head, + int32_t n_kv, + int32_t n_seqs) { + struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); + + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, state_copy); + + // clear states of sequences which are starting at the beginning of this batch + // FIXME: zero-out NANs? + states = ggml_mul(ctx, states, state_mask); + + // copy states which won't be changed further (between n_seqs and n_rs) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + + // the part of the states that will be used and modified + return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); +} + +// TODO: split +static struct ggml_tensor * llm_build_mamba( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = batch.n_seqs; + // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) + const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; + // Use the same RMS norm as the final layer norm + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + + // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms) { + dt = ggml_rms_norm(ctx, dt, norm_rms_eps); + B = ggml_rms_norm(ctx, B, norm_rms_eps); + C = ggml_rms_norm(ctx, C, norm_rms_eps); + } + + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + + // TODO: skip computing output earlier for unused tokens + + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; const llama_hparams & hparams; const llama_cparams & cparams; - const llama_batch & batch; + const llama_ubatch & batch; const llama_kv_cache & kv_self; const int64_t n_embd; @@ -8403,7 +9213,7 @@ struct llm_build_context { // TODO: consider making the entire interface noexcept llm_build_context( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, const llm_build_cb & cb, bool worst_case) : model (lctx.model), @@ -8510,29 +9320,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -8667,7 +9454,7 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; @@ -8680,13 +9467,6 @@ struct llm_build_context { return lctx.inp_s_mask; } - struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); - cb(lctx.inp_s_seq, "inp_s_seq", -1); - ggml_set_input(lctx.inp_s_seq); - return lctx.inp_s_seq; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -12016,125 +12796,31 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - const int64_t d_model = n_embd; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - GGML_ASSERT(2 * d_model == d_inner); - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - struct ggml_tensor * cur; struct ggml_tensor * inpL; // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - // clear states of sequences which are starting at the beginning of this batch - { - conv_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), - state_mask); - ssm_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), - state_mask); - } - - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur); - // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); - // conv - { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); - - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - - // extract x from x_conv - x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); - - // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); - - x = ggml_silu(ctx0, x); - } - - // ssm - { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x); - // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); - - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); - - struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - x = ggml_get_rows(ctx0, x, inp_out_ids); - y = ggml_get_rows(ctx0, y, inp_out_ids); - z = ggml_get_rows(ctx0, z, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); - - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } // residual @@ -13749,11 +14435,259 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_nemotron() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_exaone() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -13769,8 +14703,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const } static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -13785,26 +14719,9 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; - - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - - struct llm_build_context llm(lctx, dummy, cb, false); - - llm.init(); - - struct ggml_cgraph * result = llm.build_s_copy(); - - llm.free(); - - return result; -} - static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, bool worst_case) { const auto & model = lctx.model; @@ -14004,6 +14921,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_jais(); } break; + case LLM_ARCH_NEMOTRON: + { + result = llm.build_nemotron(); + } break; + case LLM_ARCH_EXAONE: + { + result = llm.build_exaone(); + } break; default: GGML_ABORT("fatal error"); } @@ -14066,7 +14991,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t return relative_bucket; } -static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { +static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // // set input data // @@ -14105,10 +15030,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.logits) { + } else if (batch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.logits[i]) { + if (batch.output[i]) { data[n_outputs++] = i; } } @@ -14132,8 +15057,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !lctx.is_encoding) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; float * data = nullptr; @@ -14153,32 +15080,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; - for (int i = 0; i < n_kv; ++i) { - float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); - } else { - f = 0.0f; - } - } + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = batch.pos[s*n_seq_tokens + j]; - if (data) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; - } - - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } } @@ -14200,8 +15130,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -14209,27 +15141,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = batch.seq_id[s1][0]; - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(batch.pos[i] - batch.pos[j]); - } else { - f = 0.0f; + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < batch.n_seq_id[s0]; ++s) { + if (batch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(batch.pos[ti] - batch.pos[tj]); + } else { + f = 0.0f; + } + break; + } } - break; + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; } } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } } } } @@ -14237,7 +15177,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -14246,12 +15188,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += 1; + sum[seq_id] += batch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -14262,14 +15206,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - data[seq_id*n_tokens + i] = div[seq_id]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } } } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -14277,20 +15226,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { uint32_t * data = (uint32_t *) lctx.inp_cls->data; memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); - if (pos == 0) { - data[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } } } } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -14301,15 +15256,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { std::vector last_pos(n_tokens, -1); std::vector last_row(n_tokens, -1); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); - if (pos >= last_pos[seq_id]) { - last_pos[seq_id] = pos; - last_row[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } } } @@ -14327,41 +15286,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched + // clear unused states for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) (kv_cell.src >= 0); - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; } } } - // For Mamba (and other recurrent architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Like with the KQ_mask, if a token in the batch has multiple sequences, - // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). - if (lctx.inp_s_seq) { - const int64_t n_tokens = batch.n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_seq->data; + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - for (int i = 0; i < n_kv; ++i) { - if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; - } else { - data[j*n_kv + i] = -1; - } + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; } } } @@ -14371,6 +15328,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; @@ -14406,6 +15364,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); + GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing float * data = (float *) lctx.inp_KQ_mask_cross->data; @@ -14499,11 +15458,49 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { return n_outputs_max; } +// make the outputs have the same order they had in the user-provided batch +static void llama_output_reorder(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx->logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } + } + if (ctx->embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} static void llama_graph_compute( - llama_context & lctx, - ggml_cgraph * gf, - int n_threads) { + llama_context & lctx, + ggml_cgraph * gf, + int n_threads, + ggml_threadpool * threadpool) { #ifdef GGML_USE_METAL if (ggml_backend_is_metal(lctx.backend_metal)) { ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); @@ -14512,6 +15509,7 @@ static void llama_graph_compute( if (lctx.backend_cpu != nullptr) { ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); + ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool); ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); } #ifdef GGML_USE_BLAS @@ -14571,15 +15569,11 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - // TODO: simplify or deprecate - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + lctx.embd_seq.clear(); + // count outputs if (batch_all.logits && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -14592,55 +15586,42 @@ static int llama_decode_internal( n_outputs = 1; } + lctx.sbatch.from_batch(batch_all, n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); + // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); return -2; }; - // set output mappings - if (batch_all.logits) { - int32_t i_logits = 0; - for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch_all.logits[i]) { - lctx.output_ids[i] = i_logits++; + while (lctx.sbatch.n_tokens > 0) { + llama_ubatch ubatch; + if (kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(n_ubatch); } + } else { + ubatch = lctx.sbatch.split_simple(n_ubatch); } - } else { - for (uint32_t i = 0; i < n_outputs; ++i) { - lctx.output_ids[i] = i; - } - } - - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - llama_batch u_batch = { - /* .n_tokens = */ (int32_t) n_tokens, - /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, - /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, - /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, - /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, - /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, - /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, - /* .all_pos_1 = */ batch_all.all_pos_1, - /* .all_seq_id = */ batch_all.all_seq_id, - }; + const uint32_t n_tokens = ubatch.n_tokens; // count the outputs in this u_batch { int32_t n_outputs_new = 0; - if (u_batch.logits && !embd_pooled) { - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.logits[i] != 0; - } - } else if (n_outputs == n_tokens_all) { + if (n_outputs == n_tokens_all) { n_outputs_new = n_tokens; } else { - // keep last output only - if (cur_token + n_tokens >= n_tokens_all) { - n_outputs_new = 1; + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < n_tokens; i++) { + n_outputs_new += (int32_t) (ubatch.output[i] != 0); } } @@ -14649,34 +15630,10 @@ static int llama_decode_internal( } int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; + GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (u_batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1; - } - - u_batch.pos = pos.data(); - } - - if (u_batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = u_batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - u_batch.n_seq_id = n_seq_id.data(); - u_batch.seq_id = seq_id_arr.data(); - } - // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); @@ -14687,7 +15644,7 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(kv_self, ubatch)) { return 1; } @@ -14706,7 +15663,7 @@ static int llama_decode_internal( ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false); + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false); // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; @@ -14734,9 +15691,9 @@ static int llama_decode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); - llama_set_inputs(lctx, u_batch); + llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads); + llama_graph_compute(lctx, gf, n_threads, threadpool); // update the kv ring buffer { @@ -14792,12 +15749,11 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: { - // extract sequence embeddings + // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; - embd_seq_out.clear(); - for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = u_batch.seq_id[i][0]; + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } @@ -14814,6 +15770,25 @@ static int llama_decode_internal( n_outputs_prev += lctx.n_outputs; } + // set output mappings + { + bool sorted_output = true; + + GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs); + + for (size_t i = 0; i < n_outputs; ++i) { + size_t out_id = lctx.sbatch.out_ids[i]; + lctx.output_ids[out_id] = i; + if (out_id != i) { + sorted_output = false; + } + } + + if (sorted_output) { + lctx.sbatch.out_ids.clear(); + } + } + // set to total number of outputs in the batch, for use in llama_get_logits_ith lctx.n_outputs = n_outputs; @@ -14878,11 +15853,9 @@ static int llama_encode_internal( const int64_t n_embd = hparams.n_embd; - // TODO: simplify or deprecate - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; + lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); // reserve output buffer if (llama_output_reserve(lctx, n_tokens) < n_tokens) { @@ -14897,39 +15870,15 @@ static int llama_encode_internal( lctx.inp_embd_enc = NULL; lctx.n_outputs = n_tokens; - const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; + GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = batch.all_pos_0 + i*batch.all_pos_1; - } - - batch.pos = pos.data(); - } - - if (batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - batch.n_seq_id = n_seq_id.data(); - batch.seq_id = seq_id_arr.data(); - } - ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - ggml_cgraph * gf = llama_build_graph(lctx, batch, false); + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false); // the output embeddings after the final encoder normalization struct ggml_tensor * embd = nullptr; @@ -14953,9 +15902,9 @@ static int llama_encode_internal( ggml_backend_sched_alloc_graph(lctx.sched, gf); - llama_set_inputs(lctx, batch); + llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads); + llama_graph_compute(lctx, gf, n_threads, threadpool); // extract embeddings if (embd) { @@ -14967,12 +15916,13 @@ static int llama_encode_internal( float * embd_out = lctx.embd_enc.data(); ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits // remember the sequence ids used during the encoding - needed for cross attention later lctx.seq_ids_enc.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { - for (int s = 0; s < batch.n_seq_id[i]; s++) { - llama_seq_id seq_id = batch.seq_id[i][s]; + for (int s = 0; s < ubatch.n_seq_id[i]; s++) { + llama_seq_id seq_id = ubatch.seq_id[i][s]; lctx.seq_ids_enc[i].insert(seq_id); } } @@ -14997,8 +15947,10 @@ static int llama_encode_internal( auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_seq_id seq_id = ubatch.seq_id[i][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } @@ -15234,7 +16186,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids); - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); + llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool); #endif //const int64_t t_end = ggml_time_us(); @@ -15260,7 +16212,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { llama_set_k_shift(lctx); - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); + llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool); need_reserve = true; } @@ -15276,32 +16228,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - { - ggml_backend_sched_reset(lctx.sched); - - ggml_cgraph * gf = llama_build_graph_s_copy(lctx); - - ggml_backend_sched_alloc_graph(lctx.sched, gf); - - llama_set_s_copy(lctx); - - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); - - need_reserve = true; - } - - { - auto & kv_self = lctx.kv_self; - - kv_self.do_copy = false; - - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].src = i; - } - } - } - // defragment the KV cache if needed if (lctx.kv_self.do_defrag) { llama_kv_cache_defrag_internal(lctx); @@ -15315,10 +16241,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph - int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); - int n_past = lctx.cparams.n_ctx - n_tokens; + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph ggml_backend_sched_reset(lctx.sched); @@ -15710,6 +16637,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); ++qs.n_fallback; } @@ -15901,7 +16831,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // TODO: avoid hardcoded tensor names - use the TN_* constants if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos) { + name.find("attn_qkv.weight") != std::string::npos || + name.find("attn_kv_b.weight")!= std::string::npos) { ++qs.n_attention_wv; } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; @@ -15911,12 +16842,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks - // - // - qs.n_attention_wv == 0 for Mamba models - // - qs.n_attention_wv == model.hparams.n_layer for Transformer models - // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models - // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + { + const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); + // attention layers have a non-zero number of kv heads + int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + if (llama_model_has_encoder(&model)) { + n_attn_layer *= 3; + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } size_t total_size_org = 0; size_t total_size_new = 0; @@ -16038,8 +16972,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // do not quantize Mamba's small yet 2D weights // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - quantize &= name.find("ssm_x.weight") == std::string::npos; - quantize &= name.find("ssm_dt.weight") == std::string::npos; // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; @@ -16528,6 +17460,19 @@ void llama_numa_init(enum ggml_numa_strategy numa) { } } +void llama_attach_threadpool( + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + ctx->threadpool = threadpool; + ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; +} + +void llama_detach_threadpool(struct llama_context * ctx) { + ctx->threadpool = nullptr; + ctx->threadpool_batch = nullptr; +} + void llama_backend_free(void) { ggml_quantize_free(); } @@ -16613,12 +17558,6 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.attn_soft_cap) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__); - params.flash_attn = false; - } - - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; @@ -16727,7 +17666,7 @@ struct llama_context * llama_new_context_with_model( ggml_type type_v = params.type_v; // Mamba only needs a constant number of KV cache cells per sequence - if (model->arch == LLM_ARCH_MAMBA) { + if (llama_model_is_recurrent(model)) { // Mamba needs at least as many KV cells as there are sequences kept at any time kv_size = std::max((uint32_t) 1, params.n_seq_max); // it's probably best to keep as much precision as possible for the states @@ -16959,10 +17898,11 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); - int n_past = cparams.n_ctx - n_tokens; + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph if (!ggml_backend_sched_reserve(ctx->sched, gf)) { @@ -17074,6 +18014,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: + case LLM_ARCH_NEMOTRON: + case LLM_ARCH_EXAONE: return LLAMA_ROPE_TYPE_NEOX; // all model arches should be listed explicitly here @@ -17200,6 +18142,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { return model->hparams.dec_start_token_id; } +bool llama_model_is_recurrent(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_MAMBA: return true; + default: return false; + } +} + uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, @@ -17521,7 +18470,9 @@ struct llama_data_write { write_string(rng_str); } - void write_output_ids(const struct llama_context * ctx) { + void write_output_ids(struct llama_context * ctx) { + llama_output_reorder(ctx); + const uint32_t n_outputs = ctx->n_outputs; std::vector output_pos; @@ -17809,8 +18760,11 @@ struct llama_data_read { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - llama_batch batch = llama_batch_init(cell_count, 0, 1); + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -17824,11 +18778,10 @@ struct llama_data_read { } batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; if (!llama_kv_cache_find_slot(kv_self, batch)) { - llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } @@ -17840,9 +18793,6 @@ struct llama_data_read { GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - - // Cleanup - llama_batch_free(batch); } else { // whole KV cache restore @@ -17874,6 +18824,15 @@ struct llama_data_read { } cell.seq_id.insert(seq_id); + + if (kv_self.recurrent) { + int32_t & tail = kv_self.cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } } } @@ -17881,6 +18840,14 @@ struct llama_data_read { kv_self.used = cell_count; } + if (kv_self.recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = kv_self.head + i; + // make sure the recurrent states will keep their restored state + kv_self.cells[cell_id].src = cell_id; + } + } + return true; } @@ -18422,16 +19389,16 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa } } -void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { +void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { ctx->cparams.n_threads = n_threads; ctx->cparams.n_threads_batch = n_threads_batch; } -uint32_t llama_n_threads(struct llama_context * ctx) { +int32_t llama_n_threads(struct llama_context * ctx) { return ctx->cparams.n_threads; } -uint32_t llama_n_threads_batch(struct llama_context * ctx) { +int32_t llama_n_threads_batch(struct llama_context * ctx) { return ctx->cparams.n_threads_batch; } @@ -18468,7 +19435,18 @@ struct llama_batch llama_batch_get_one( } struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ 0, + /*all_pos_1 =*/ 0, + /*all_seq_id =*/ 0, + }; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); @@ -18554,6 +19532,10 @@ void llama_synchronize(struct llama_context * ctx) { float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(ctx); + return ctx->logits; } @@ -18598,6 +19580,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(ctx); + return ctx->embd; } @@ -18699,11 +19685,11 @@ llama_token llama_token_pad(const struct llama_model * model) { return llama_token_pad_impl(model->vocab); } -int32_t llama_add_bos_token(const struct llama_model * model) { +bool llama_add_bos_token(const struct llama_model * model) { return llama_add_bos_token_impl(model->vocab); } -int32_t llama_add_eos_token(const struct llama_model * model) { +bool llama_add_eos_token(const struct llama_model * model) { return llama_add_eos_token_impl(model->vocab); } @@ -19004,6 +19990,22 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "Assistant:"; } + } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } } else { // template not supported return -1; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2f4117a62..c832bc956 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -949,6 +949,58 @@ struct test_rms_norm : public test_case { } }; +// GGML_OP_SSM_CONV +struct test_ssm_conv : public test_case { + const ggml_type type; + const std::array ne_a; + const std::array ne_b; + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, ne_b); + } + + test_ssm_conv(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {10, 10, 10, 1}, + std::array ne_b = {3, 3, 1, 1}) + : type(type), ne_a(ne_a), ne_b(ne_b) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * out = ggml_ssm_conv(ctx, a, b); + return out; + } +}; + +// GGML_OP_SSM_SCAN +struct test_ssm_scan : public test_case { + const ggml_type type; + + const int64_t d_state; + const int64_t d_inner; + const int64_t n_seq_tokens; + const int64_t n_seqs; + + std::string vars() override { + return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + } + + test_ssm_scan(ggml_type type = GGML_TYPE_F32, + int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); + ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); + ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + return out; + } +}; + // GGML_OP_MUL_MAT struct test_mul_mat : public test_case { const ggml_type type_a; @@ -1108,6 +1160,58 @@ struct test_sqrt : public test_case { } }; +// GGML_OP_SIN +struct test_sin : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_sin(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_sin(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -100.0f, 100.0f); + } + } +}; + +// GGML_OP_COS +struct test_cos : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cos(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_cos(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -100.0f, 100.0f); + } + } +}; + // GGML_OP_CLAMP struct test_clamp : public test_case { const ggml_type type; @@ -1652,19 +1756,20 @@ struct test_flash_attn_ext : public test_case { const bool mask; // use mask const float max_bias; // ALiBi + const float logit_softcap; // Gemma 2 const ggml_type type_KV; std::string vars() override { - return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV); + return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); @@ -1673,7 +1778,28 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr; - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap); + return out; + } +}; + +// GGML_OP_CROSS_ENTROPY_LOSS +struct test_cross_entropy_loss : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels); return out; } }; @@ -2145,6 +2271,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + // sycl backend will limit task global_range < MAX_INT + // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) + // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) + // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) + // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true)); + // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_conv_transpose_1d()); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); @@ -2232,6 +2365,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); + + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { @@ -2287,6 +2426,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1})); + // sycl backend will limit task global_range < MAX_INT + // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion) + // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.) + // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) + // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1})); + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {4, 8}) { @@ -2321,6 +2466,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_sqr()); test_cases.emplace_back(new test_sqrt()); + test_cases.emplace_back(new test_sin()); + test_cases.emplace_back(new test_cos()); test_cases.emplace_back(new test_clamp()); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -2424,11 +2571,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (bool mask : { true, false } ) { for (float max_bias : { 0.0f, 8.0f }) { if (!mask && max_bias > 0.0f) continue; - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV)); + for (float logit_softcap : {0.0f, 10.0f}) { + if (hs != 128 && logit_softcap != 0.0f) continue; + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); + } } } } @@ -2437,6 +2587,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_cross_entropy_loss()); + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 test_cases.emplace_back(new test_llama(1)); @@ -2470,7 +2622,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } GGML_ABORT("fatal error"); - return false; } static void usage(char ** argv) { diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index a35327645..1834c11d8 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1,10 +1,14 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #include "ggml.h" +#include #include +#include #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -217,7 +221,8 @@ static bool check_gradient( int nargs, float eps, float max_error_abs, - float max_error_rel) { + float max_error_rel, + std::vector expected_vals) { static int n_threads = -1; if (n_threads < 0) { @@ -248,9 +253,10 @@ static bool check_gradient( // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot"); for (int i = 0; i < nargs; ++i) { + bool all_g0_bad = true; const int nelements = ggml_nelements(x[i]); for (int k = 0; k < nelements; ++k) { - // compute gradient using finite differences + // Calculate gradient numerically: const float x0 = ggml_get_f32_1d(x[i], k); const float xm = x0 - eps; const float xp = x0 + eps; @@ -267,6 +273,28 @@ static bool check_gradient( const double f1 = ggml_get_f32_1d(f, 0); const double g0 = (f0 - f1)/(2.0*(double) eps); + // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU). + // In such cases, provide a vector of expected values and skip the comparison for failed calculations. + if (!expected_vals.empty()) { + bool matches_any = false; + for (const double & ev : expected_vals) { + const double error_abs = std::fabs(g0 - ev); + if (error_abs > max_error_abs) { + continue; + } + const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0; + if (error_rel > max_error_rel) { + continue; + } + matches_any = true; + break; + } + if (!matches_any) { + continue; + } + } + all_g0_bad = false; + ggml_set_f32_1d(x[i], k, x0); // compute gradient using backward graph @@ -278,7 +306,7 @@ static bool check_gradient( const double g1 = ggml_get_f32_1d(x[i]->grad, k); const double error_abs = fabs(g0 - g1); - const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0; + const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0; if (error_abs > max_error_abs || error_rel > max_error_rel) { printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n", @@ -287,6 +315,10 @@ static bool check_gradient( return false; } } + if (all_g0_bad) { + printf("%s: numerical calculation of the gradient failed for all values\n", op_name); + return false; + } } return true; @@ -404,7 +436,7 @@ int main(int argc, const char ** argv) { seed_iter = rand(); unsigned seed = rand(); - printf("test-grad0: iter:%d/%d\n", iter, niter); + printf("test-grad0: iter:%d/%d\n", (iter+1), niter); struct ggml_context * ctx0 = ggml_init(params); get_random_dims(ne, 4); @@ -424,7 +456,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f); + check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {}); } } @@ -441,7 +473,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f); + check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {}); } } @@ -458,7 +490,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1])); - check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -475,7 +507,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1])); - check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -492,7 +524,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1])); - check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f); + check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {}); } } @@ -509,7 +541,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0])); - check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -526,7 +558,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0])); - check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f); + check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {}); } } @@ -543,7 +575,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0])); - check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f); + check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {}); } } @@ -560,7 +592,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, x[0]); - check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -578,7 +610,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0]))); - check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } @@ -596,7 +628,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0])); - check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -614,7 +646,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0])); - check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -637,7 +669,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); - check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } @@ -660,25 +692,25 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0])))); - check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); + check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); } } - // abs (finite differences do not work) - //{ - // const int nargs = 1; + // abs + { + const int nargs = 1; - // for (int ndims = 1; ndims <= 2; ++ndims) { - // for (int i = 0; i < nargs; ++i) { - // x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - // ggml_set_param(ctx0, x[i]); - // } + for (int ndims = 1; ndims <= 4; ++ndims) { + for (int i = 0; i < nargs; ++i) { + x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); + ggml_set_param(ctx0, x[i]); + } - // struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); - // check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f); - // } - //} + check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0}); + } + } // sgn { @@ -693,7 +725,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0])); - check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); } } @@ -710,7 +742,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0])); - check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -727,7 +759,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0])); - check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); } } @@ -745,7 +777,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0])); - check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -776,7 +808,7 @@ int main(int argc, const char ** argv) { GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims); - check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); if (ndims == 2) { // check_mat_mul does not support ndims > 2 check_mat_mul(m, x[1], x[0]); @@ -800,7 +832,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0])); - check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -817,7 +849,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0])); - check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0}); } } @@ -835,7 +867,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0])); - check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f); + check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); } } @@ -854,9 +886,9 @@ int main(int argc, const char ** argv) { #ifdef GGML_SILU_FP16 // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds. - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY); + check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {}); #else - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); #endif } } @@ -874,7 +906,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); - check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY); + check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {}); } } @@ -892,7 +924,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s)); - check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -910,7 +942,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -928,7 +960,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); + check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); } } @@ -952,7 +984,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -976,7 +1008,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1004,7 +1036,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1037,7 +1069,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1072,7 +1104,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1109,7 +1141,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1137,7 +1169,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset)); - check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1170,7 +1202,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset)); - check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1194,7 +1226,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset)); - check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1225,7 +1257,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset)); - check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1257,7 +1289,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset)); - check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1291,7 +1323,7 @@ int main(int argc, const char ** argv) { // sum requires contiguous tensor rows struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3))); - check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1319,7 +1351,7 @@ int main(int argc, const char ** argv) { // sum requires contiguous tensor rows struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0]))); - check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1337,7 +1369,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1])); - check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // diag_mask_inf @@ -1353,7 +1385,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past)); - check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // diag_mask_zero @@ -1369,7 +1401,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past)); - check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } // softmax @@ -1395,7 +1427,7 @@ int main(int argc, const char ** argv) { 1.0f - eps), ggml_new_f32(ctx0, eps)))); - check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY); + check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {}); // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf. // this may result in different gradients too finite differences. // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause. @@ -1412,7 +1444,7 @@ int main(int argc, const char ** argv) { get_random_dims(ne2, 4); for (int ndims = 1; ndims <= 4; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f); + x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); // the second argument to cross_entropy_loss must sum up to 1 for each row int nr = ggml_nrows(x[1]); @@ -1430,7 +1462,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY); + check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); } } @@ -1468,7 +1500,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); + check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); } } } @@ -1508,12 +1540,93 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); + check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); } } } } + // im2col f32 + { + srand(seed); + const int nargs = 1; + const int ndims = 4; + + for (const bool is_2D : {false, true}) { + int64_t ne0[ndims]; + int64_t ne1[ndims]; + get_random_dims(ne0, ndims); + get_random_dims(ne1, ndims); + + // // Ensure that the output is not zero-sized: + ne1[0] += 8; + ne1[1] += 8; + + if (is_2D) { + ne1[2] = ne0[2]; + } else { + ne1[1] = ne0[1]; + ne0[3] = 1; + ne1[3] = 1; + } + + // The order of arguments is swapped because the first tensor is only used for its shape. + x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f); + x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f); + + ggml_set_param(ctx0, x[0]); + + const int s0 = 1 + irand(2); + const int s1 = is_2D ? 1 + irand(2) : 0; + const int p0 = 0 + irand(2); + const int p1 = is_2D ? 0 + irand(2) : 0; + const int d0 = 1 + irand(2); + const int d1 = is_2D ? 1 + irand(2) : 0; + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32)); + + GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1); + check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); + } + } + + // pool_2d f32 + { + srand(seed); + const int nargs = 1; + const int ndims = 4; + + for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { + int64_t ne0[ndims]; + get_random_dims(ne0, ndims); + + ne0[0] += 8; + ne0[1] += 8; + + x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f); + + ggml_set_param(ctx0, x[0]); + + const int k0 = 2 + irand(2); + const int k1 = 2 + irand(2); + const int s0 = 2 + irand(2); + const int s1 = 2 + irand(2); + const int p0 = 0 + irand(2); + const int p1 = 0 + irand(2); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1)); + + GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n", + op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1); + std::vector expected_vals; + if (op == GGML_OP_POOL_MAX) { + expected_vals.push_back(0.0); + expected_vals.push_back(1.0); + } + check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals); + } + } + // flash_attn f32 // TODO: adapt to ggml_flash_attn_ext() changes //{ @@ -1553,7 +1666,7 @@ int main(int argc, const char ** argv) { // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); + // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {}); // } // } // } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 68f971bfe..9c4e7d18e 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -503,7 +503,7 @@ static void test_special_chars() { "aaaaabcccc", "aaaabccc", "aaaabccccc", - "🔵🟠✅❌abc❌✅🟠🔵" + "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" } ); diff --git a/tests/test-lora-conversion-inference.sh b/tests/test-lora-conversion-inference.sh new file mode 100755 index 000000000..fe90ce0d1 --- /dev/null +++ b/tests/test-lora-conversion-inference.sh @@ -0,0 +1,139 @@ +#!/bin/bash +set -e + +# Array of models to iterate over +declare -a params=( + "Gemma2ForCausalLM 64" + "LlamaForCausalLM 64" + "Phi3ForCausalLM 64" +) + +MODELS_REPO=lora-tests +MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO + +# Clone the Hugging Face repository if the directory does not exist +if [ ! -d "$MODELS_REPO" ]; then + echo "Cloning the Hugging Face repository..." + git clone $MODELS_REPO_URL --depth 1 +else + echo "Repository already exists. Skipping clone." +fi + +# Array to store results to print +results=() + +trim_leading_whitespace() { + local input_string="$1" + echo "${input_string#"${input_string%%[![:space:]]*}"}" +} + +extract_starting_substring() { + local reference_string="$1" + local target_string="$2" + + local target_length=${#target_string} + echo "${reference_string:0:$target_length}" +} + +get_first_word() { + local input_string="$1" + read -r first_word _ <<< "$input_string" + echo "$first_word" +} + +# Load the expected strings +EXPECTED_BASE_FULL=$(cat $MODELS_REPO/data/pale_blue_dot.txt) +EXPECTED_LORA_FULL=$(cat $MODELS_REPO/data/bohemian_rhapsody.txt) +EXPECTED_BASE_FIRST_WORD=$(get_first_word "$EXPECTED_BASE_FULL") +EXPECTED_LORA_FIRST_WORD=$(get_first_word "$EXPECTED_LORA_FULL") + +run_conversion_and_inference_lora() { + local model_name=$1 + local hidden_size=$2 + + echo -e "\n\n-------- RUNNING TEST FOR MODEL $model_name --------\n\n" + + # Convert safetensors to gguf + echo "Running convert_hf_to_gguf.py for $model_name with hidden_size $hidden_size..." + python convert_hf_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \ + --outfile $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + --outtype f32 + + echo -e "\n\n---------------------------\n\n" + echo "Running convert_lora_to_gguf.py for $model_name with hidden_size $hidden_size..." + python3 convert_lora_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora \ + --base $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \ + --outtype f32 + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-export-lora with lora for $model_name with hidden_size $hidden_size..." + ./llama-export-lora \ + -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + -o $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \ + --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf + + # Run inference + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..." + OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + -p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0) + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..." + OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ + --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \ + -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) + + echo -e "\n\n---------------------------\n\n" + echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..." + OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \ + -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) + + # Remove any initial white space + OUTPUT_BASE=$(trim_leading_whitespace "$OUTPUT_BASE") + OUTPUT_LORA_HOT=$(trim_leading_whitespace "$OUTPUT_LORA_HOT") + OUTPUT_LORA_MERGED=$(trim_leading_whitespace "$OUTPUT_LORA_MERGED") + # Extract the corresponding substring from full string + EXPECTED_BASE=$(extract_starting_substring "$EXPECTED_BASE_FULL" "$OUTPUT_BASE") + EXPECTED_LORA=$(extract_starting_substring "$EXPECTED_LORA_FULL" "$OUTPUT_LORA_HOT") + + # Assert output equals the expected output + if [[ "$OUTPUT_BASE" != "$EXPECTED_BASE" ]]; then + echo "Error: $model_name OUTPUT_BASE does not start with the expected string." + echo -e "Out=$OUTPUT_BASE\n\nExp=$EXPECTED_BASE" + exit 1 + fi + if [[ "$OUTPUT_LORA_HOT" != "$EXPECTED_LORA" ]]; then + echo "Error: $model_name OUTPUT_LORA_HOT does not start with the expected string." + echo -e "Out=$OUTPUT_LORA_HOT\n\nExp=$EXPECTED_LORA" + exit 1 + fi + if [[ "$OUTPUT_LORA_MERGED" != "$EXPECTED_LORA" ]]; then + echo "Error: $model_name OUTPUT_LORA_MERGED does not start with the expected string." + echo -e "Out=$OUTPUT_LORA_MERGED\n\nExp=$EXPECTED_LORA" + exit 1 + fi + + # Store the results + results+=(" + \n\033[1mResults for $model_name with hidden_size $hidden_size:\033[0m + \n\033[32m • Base:\n$OUTPUT_BASE + \n\033[34m • Lora hot:\n$OUTPUT_LORA_HOT + \n\033[36m • Lora merged:\n$OUTPUT_LORA_MERGED + \n \033[0m + ") + + echo "All tests passed for $model_name with hidden_size $hidden_size!" +} + +# Run test for each model +for param in "${params[@]}"; do + run_conversion_and_inference_lora $param +done + +# Print results +echo -e "\n\n---------------------------\n\n" +echo -e "\n\033[1mSummary of All Results:\033[0m" +for result in "${results[@]}"; do + echo -e "$result" +done diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp index 8159e276a..246bb227d 100644 --- a/tests/test-rope.cpp +++ b/tests/test-rope.cpp @@ -113,7 +113,7 @@ static struct ggml_tensor * get_random_tensor_f32( } static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); if (plan.work_size > 0) { buf.resize(plan.work_size); diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index de858bd3b..6c2a5db9a 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -166,12 +166,12 @@ static void test_sampler_queue( for (auto s : samplers_sequence) { switch (s){ case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; - case 'f': GGML_ABORT("tail_free test not implemented"); break; - case 'y': GGML_ABORT("typical test not implemented"); break; + case 'f': GGML_ABORT("tail_free test not implemented"); + case 'y': GGML_ABORT("typical test not implemented"); case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; - case 't': GGML_ABORT("temperature test not implemented"); break; - default : GGML_ABORT("Unknown sampler"); break; + case 't': GGML_ABORT("temperature test not implemented"); + default : GGML_ABORT("Unknown sampler"); } llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests