Merge branch 'ggerganov:master' into cuda-releases

This commit is contained in:
Olivier Chafik 2025-01-20 10:24:54 +00:00 committed by GitHub
commit 67075cc8bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
87 changed files with 6216 additions and 1741 deletions

View file

@ -87,6 +87,7 @@ jobs:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
- name: Upload artifacts - name: Upload artifacts
@ -149,6 +150,7 @@ jobs:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
- name: Upload artifacts - name: Upload artifacts
@ -217,6 +219,7 @@ jobs:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: | run: |
cp LICENSE ./build/bin/ cp LICENSE ./build/bin/
cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/* zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*
- name: Upload artifacts - name: Upload artifacts
@ -234,7 +237,7 @@ jobs:
strategy: strategy:
matrix: matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED] sanitizer: [ADDRESS, THREAD, UNDEFINED]
build_type: [Debug, Release] build_type: [Debug]
steps: steps:
- name: Clone - name: Clone
@ -796,6 +799,7 @@ jobs:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: | run: |
Copy-Item LICENSE .\build\bin\Release\llama.cpp.txt Copy-Item LICENSE .\build\bin\Release\llama.cpp.txt
Copy-Item .\examples\run\linenoise.cpp\LICENSE .\build\bin\Release\linenoise.cpp.txt
7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\* 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*
- name: Upload artifacts - name: Upload artifacts

View file

@ -112,9 +112,9 @@ jobs:
-DGGML_OPENMP=OFF ; -DGGML_OPENMP=OFF ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build - name: Build (sanitizers)
id: cmake_build id: cmake_build_sanitizers
if: ${{ matrix.sanitizer != 'THREAD' }} if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
run: | run: |
cmake -B build \ cmake -B build \
-DGGML_NATIVE=OFF \ -DGGML_NATIVE=OFF \
@ -124,12 +124,31 @@ jobs:
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build
if: ${{ matrix.sanitizer == '' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_BUILD_SERVER=ON \
-DLLAMA_CURL=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Tests - name: Tests
id: server_integration_tests id: server_integration_tests
if: ${{ matrix.sanitizer == '' }}
run: | run: |
cd examples/server/tests cd examples/server/tests
./tests.sh ./tests.sh
- name: Tests (sanitizers)
id: server_integration_tests_sanitizers
if: ${{ matrix.sanitizer != '' }}
run: |
cd examples/server/tests
LLAMA_SANITIZE=1 ./tests.sh
- name: Slow tests - name: Slow tests
id: server_integration_tests_slow id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}

1
.gitignore vendored
View file

@ -18,6 +18,7 @@
*.metallib *.metallib
*.o *.o
*.so *.so
*.swp
*.tmp *.tmp
# IDE / OS # IDE / OS

View file

@ -83,11 +83,8 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)
# override ggml options # override ggml options
set(GGML_SANITIZE_THREAD ${LLAMA_SANITIZE_THREAD}) set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})
set(GGML_SANITIZE_ADDRESS ${LLAMA_SANITIZE_ADDRESS}) set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})
set(GGML_SANITIZE_UNDEFINED ${LLAMA_SANITIZE_UNDEFINED})
set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})
set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})
# change the default for these ggml options # change the default for these ggml options
if (NOT DEFINED GGML_LLAMAFILE) if (NOT DEFINED GGML_LLAMAFILE)
@ -117,16 +114,62 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
llama_option_depr(WARNING LLAMA_CANN GGML_CANN) llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
if (NOT MSVC)
if (LLAMA_SANITIZE_THREAD)
message(STATUS "Using -fsanitize=thread")
add_compile_options(-fsanitize=thread)
link_libraries (-fsanitize=thread)
endif()
if (LLAMA_SANITIZE_ADDRESS)
message(STATUS "Using -fsanitize=address")
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
link_libraries (-fsanitize=address)
endif()
if (LLAMA_SANITIZE_UNDEFINED)
message(STATUS "Using -fsanitize=undefined")
add_compile_options(-fsanitize=undefined)
link_libraries (-fsanitize=undefined)
endif()
endif()
# #
# build the library # 3rd-party
# #
if (NOT TARGET ggml) if (NOT TARGET ggml)
add_subdirectory(ggml) add_subdirectory(ggml)
# ... otherwise assume ggml is added by a parent CMakeLists.txt # ... otherwise assume ggml is added by a parent CMakeLists.txt
endif() endif()
#
# build the library
#
add_subdirectory(src) add_subdirectory(src)
#
# utils, programs, examples and tests
#
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest)
add_subdirectory(tests)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
add_subdirectory(examples)
add_subdirectory(pocs)
endif()
# #
# install # install
# #
@ -200,21 +243,3 @@ configure_file(cmake/llama.pc.in
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
DESTINATION lib/pkgconfig) DESTINATION lib/pkgconfig)
#
# utils, programs, examples and tests
#
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest)
add_subdirectory(tests)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
add_subdirectory(examples)
add_subdirectory(pocs)
endif()

View file

@ -204,6 +204,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs
- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly - [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server - [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale
</details> </details>

View file

@ -326,17 +326,17 @@ function gg_run_open_llama_7b_v2 {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@ -460,17 +460,17 @@ function gg_run_pythia_1_4b {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-cli -no-cnv --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@ -591,17 +591,17 @@ function gg_run_pythia_2_8b {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log

View file

@ -376,6 +376,30 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
return devices; return devices;
} }
static void add_rpc_devices(std::string servers) {
auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
}
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend");
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
}
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
ggml_backend_device_register(dev);
} else {
throw std::invalid_argument("failed to register RPC device");
}
}
}
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
auto ctx_arg = common_params_parser_init(params, ex, print_usage); auto ctx_arg = common_params_parser_init(params, ex, print_usage);
const common_params params_org = ctx_arg.params; // the example can modify the default params const common_params params_org = ctx_arg.params; // the example can modify the default params
@ -1385,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--rpc"}, "SERVERS", {"--rpc"}, "SERVERS",
"comma separated list of RPC servers", "comma separated list of RPC servers",
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.rpc_servers = value; add_rpc_devices(value);
GGML_UNUSED(params);
} }
).set_env("LLAMA_ARG_RPC")); ).set_env("LLAMA_ARG_RPC"));
} }
@ -2229,6 +2254,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.vocoder.model = value; params.vocoder.model = value;
} }
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--tts-use-guide-tokens"},
"Use guide tokens to improve TTS word recall",
[](common_params & params) {
params.vocoder.use_guide_tokens = true;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
// model-specific // model-specific
add_opt(common_arg( add_opt(common_arg(

View file

@ -1043,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (params.n_gpu_layers != -1) { if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers; mparams.n_gpu_layers = params.n_gpu_layers;
} }
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu; mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode; mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split; mparams.tensor_split = params.tensor_split;

View file

@ -184,6 +184,8 @@ struct common_params_vocoder {
std::string model = ""; // model path // NOLINT std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT std::string model_url = ""; // model url to download // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
}; };
struct common_params { struct common_params {
@ -246,7 +248,6 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
std::vector<std::string> in_files; // all input files std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

View file

@ -2882,6 +2882,66 @@ class InternLM2Model(Model):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@Model.register("InternLM3ForCausalLM")
class InternLM3Model(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
def set_vocab(self):
tokens, scores, toktypes = self._create_vocab_sentencepiece()
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
if "added_tokens_decoder" in tokenizer_config_json:
for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items():
if token_data.get("special"):
token_id = int(token_id)
token = token_data["content"]
special_vocab._set_special_token(token, token_id)
# update eos token
if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids:
special_vocab.special_token_ids["eos"] = token_id
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
@Model.register("BertModel", "BertForMaskedLM", "CamembertModel") @Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
class BertModel(Model): class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT model_arch = gguf.MODEL_ARCH.BERT

View file

@ -41,7 +41,7 @@ echo PASS
echo echo
# 2b. Test the sharded model is loading properly # 2b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
echo PASS echo PASS
echo echo
@ -51,7 +51,7 @@ echo PASS
echo echo
# 3b. Test the merged model is loading properly # 3b. Test the merged model is loading properly
$MAIN --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
echo PASS echo PASS
echo echo
@ -61,7 +61,7 @@ echo PASS
echo echo
# 4b. Test the sharded model is loading properly # 4b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
echo PASS echo PASS
echo echo
@ -71,7 +71,7 @@ echo
#echo #echo
# 5b. Test the merged model is loading properly # 5b. Test the merged model is loading properly
#$MAIN --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32 #$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
#echo PASS #echo PASS
#echo #echo
@ -81,7 +81,7 @@ echo PASS
echo echo
# 6b. Test the sharded model is loading properly # 6b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
echo PASS echo PASS
echo echo

View file

@ -683,7 +683,7 @@ struct cmd_params_instance {
bool cpu_strict; bool cpu_strict;
int poll; int poll;
int n_gpu_layers; int n_gpu_layers;
std::string rpc_servers; std::string rpc_servers_str;
llama_split_mode split_mode; llama_split_mode split_mode;
int main_gpu; int main_gpu;
bool no_kv_offload; bool no_kv_offload;
@ -696,8 +696,37 @@ struct cmd_params_instance {
llama_model_params mparams = llama_model_default_params(); llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = n_gpu_layers; mparams.n_gpu_layers = n_gpu_layers;
if (!rpc_servers.empty()) { if (!rpc_servers_str.empty()) {
mparams.rpc_servers = rpc_servers.c_str(); auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
// add RPC devices
if (!rpc_servers.empty()) {
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
exit(1);
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
exit(1);
}
static std::vector<ggml_backend_dev_t> devices;
devices.clear();
for (const std::string & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
devices.push_back(dev);
} else {
fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
exit(1);
}
}
devices.push_back(nullptr);
mparams.devices = devices.data();
}
} }
mparams.split_mode = split_mode; mparams.split_mode = split_mode;
mparams.main_gpu = main_gpu; mparams.main_gpu = main_gpu;
@ -708,7 +737,7 @@ struct cmd_params_instance {
} }
bool equal_mparams(const cmd_params_instance & other) const { bool equal_mparams(const cmd_params_instance & other) const {
return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers && return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
tensor_split == other.tensor_split; tensor_split == other.tensor_split;
} }

View file

@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
jlong context_pointer, jlong context_pointer,
jlong batch_pointer, jlong batch_pointer,
jstring jtext, jstring jtext,
jboolean format_chat,
jint n_len jint n_len
) { ) {
@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
const auto context = reinterpret_cast<llama_context *>(context_pointer); const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto tokens_list = common_tokenize(context, text, 1); bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context); auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
} }
for (auto id : tokens_list) { for (auto id : tokens_list) {
LOGi("%s", common_token_to_piece(context, id).c_str()); LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
} }
common_batch_clear(*batch); common_batch_clear(*batch);

View file

@ -65,6 +65,7 @@ class LLamaAndroid {
context: Long, context: Long,
batch: Long, batch: Long,
text: String, text: String,
formatChat: Boolean,
nLen: Int nLen: Int
): Int ): Int
@ -115,10 +116,10 @@ class LLamaAndroid {
} }
} }
fun send(message: String): Flow<String> = flow { fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) { when (val state = threadLocalState.get()) {
is State.Loaded -> { is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) { while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur) val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) { if (str == null) {

View file

@ -47,7 +47,7 @@ echo PASS
echo echo
# 3a. Test the requanted model is loading properly # 3a. Test the requanted model is loading properly
$MAIN --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32
echo PASS echo PASS
echo echo
@ -57,7 +57,7 @@ echo PASS
echo echo
# 4b. Test the requanted model is loading properly # 4b. Test the requanted model is loading properly
$MAIN --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32 $MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32
echo PASS echo PASS
echo echo

View file

@ -1,5 +1,5 @@
set(TARGET llama-run) set(TARGET llama-run)
add_executable(${TARGET} run.cpp) add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17) target_compile_features(${TARGET} PRIVATE cxx_std_17)

View file

@ -0,0 +1,26 @@
Copyright (c) 2010-2014, Salvatore Sanfilippo <antirez at gmail dot com>
Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
/* linenoise.h -- VERSION 1.0
*
* Guerrilla line editing library against the idea that a line editing lib
* needs to be 20,000 lines of C++ code.
*
* See linenoise.cpp for more information.
*
* ------------------------------------------------------------------------
*
* Copyright (c) 2010-2023, Salvatore Sanfilippo <antirez at gmail dot com>
* Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
* Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef __LINENOISE_H
#define __LINENOISE_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h> /* For size_t. */
extern const char *linenoiseEditMore;
/* The linenoiseState structure represents the state during line editing.
* We pass this state to functions implementing specific editing
* functionalities. */
struct linenoiseState {
int in_completion; /* The user pressed TAB and we are now in completion
* mode, so input is handled by completeLine(). */
size_t completion_idx; /* Index of next completion to propose. */
int ifd; /* Terminal stdin file descriptor. */
int ofd; /* Terminal stdout file descriptor. */
char *buf; /* Edited line buffer. */
size_t buflen; /* Edited line buffer size. */
const char *prompt; /* Prompt to display. */
size_t plen; /* Prompt length. */
size_t pos; /* Current cursor position. */
size_t oldpos; /* Previous refresh cursor position. */
size_t len; /* Current edited line length. */
size_t cols; /* Number of columns in terminal. */
size_t oldrows; /* Rows used by last refrehsed line (multiline mode) */
int history_index; /* The history index we are currently editing. */
};
typedef struct linenoiseCompletions {
size_t len;
char **cvec;
} linenoiseCompletions;
/* Non blocking API. */
int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
const char *linenoiseEditFeed(struct linenoiseState *l);
void linenoiseEditStop(struct linenoiseState *l);
void linenoiseHide(struct linenoiseState *l);
void linenoiseShow(struct linenoiseState *l);
/* Blocking API. */
const char *linenoise(const char *prompt);
void linenoiseFree(void *ptr);
/* Completion API. */
typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *);
typedef const char*(linenoiseHintsCallback)(const char *, int *color, int *bold);
typedef void(linenoiseFreeHintsCallback)(const char *);
void linenoiseSetCompletionCallback(linenoiseCompletionCallback *);
void linenoiseSetHintsCallback(linenoiseHintsCallback *);
void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *);
void linenoiseAddCompletion(linenoiseCompletions *, const char *);
/* History API. */
int linenoiseHistoryAdd(const char *line);
int linenoiseHistorySetMaxLen(int len);
int linenoiseHistorySave(const char *filename);
int linenoiseHistoryLoad(const char *filename);
/* Other utilities. */
void linenoiseClearScreen(void);
void linenoiseSetMultiLine(int ml);
void linenoisePrintKeyCodes(void);
void linenoiseMaskModeEnable(void);
void linenoiseMaskModeDisable(void);
#ifdef __cplusplus
}
#endif
#endif /* __LINENOISE_H */

View file

@ -19,12 +19,14 @@
#include <cstring> #include <cstring>
#include <filesystem> #include <filesystem>
#include <iostream> #include <iostream>
#include <list>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "common.h" #include "common.h"
#include "json.hpp" #include "json.hpp"
#include "linenoise.cpp/linenoise.h"
#include "llama-cpp.h" #include "llama-cpp.h"
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
@ -536,7 +538,7 @@ class LlamaData {
llama_sampler_ptr sampler; llama_sampler_ptr sampler;
llama_context_ptr context; llama_context_ptr context;
std::vector<llama_chat_message> messages; std::vector<llama_chat_message> messages;
std::vector<std::string> msg_strs; std::list<std::string> msg_strs;
std::vector<char> fmtted; std::vector<char> fmtted;
int init(Opt & opt) { int init(Opt & opt) {
@ -807,24 +809,44 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
batch = llama_batch_get_one(&new_token_id, 1); batch = llama_batch_get_one(&new_token_id, 1);
} }
printf("\033[0m");
return 0; return 0;
} }
static int read_user_input(std::string & user) { static int read_user_input(std::string & user_input) {
std::getline(std::cin, user); static const char * prompt_prefix = "> ";
#ifdef WIN32
printf(
"\r%*s"
"\r\033[0m%s",
get_terminal_width(), " ", prompt_prefix);
std::getline(std::cin, user_input);
if (std::cin.eof()) { if (std::cin.eof()) {
printf("\n"); printf("\n");
return 1; return 1;
} }
#else
if (user == "/bye") { std::unique_ptr<char, decltype(&std::free)> line(const_cast<char *>(linenoise(prompt_prefix)), free);
if (!line) {
return 1; return 1;
} }
if (user.empty()) { user_input = line.get();
#endif
if (user_input == "/bye") {
return 1;
}
if (user_input.empty()) {
return 2; return 2;
} }
#ifndef WIN32
linenoiseHistoryAdd(line.get());
#endif
return 0; // Should have data in happy path return 0; // Should have data in happy path
} }
@ -865,10 +887,6 @@ static int handle_user_input(std::string & user_input, const std::string & user)
return 0; // No need for interactive input return 0; // No need for interactive input
} }
printf(
"\r%*s"
"\r\033[32m> \033[0m",
get_terminal_width(), " ");
return read_user_input(user_input); // Returns true if input ends the loop return read_user_input(user_input); // Returns true if input ends the loop
} }

File diff suppressed because it is too large Load diff

View file

@ -19,6 +19,7 @@
#include "loading.html.hpp" #include "loading.html.hpp"
#include <atomic> #include <atomic>
#include <chrono>
#include <condition_variable> #include <condition_variable>
#include <cstddef> #include <cstddef>
#include <cinttypes> #include <cinttypes>
@ -32,6 +33,8 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
enum stop_type { enum stop_type {
STOP_TYPE_NONE, STOP_TYPE_NONE,
STOP_TYPE_EOS, STOP_TYPE_EOS,
@ -1602,6 +1605,30 @@ struct server_response {
// should never reach here // should never reach here
} }
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
return !queue_results.empty();
});
if (!cr_res) {
return nullptr;
}
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// should never reach here
}
// single-task version of recv() // single-task version of recv()
server_task_result_ptr recv(int id_task) { server_task_result_ptr recv(int id_task) {
std::unordered_set<int> id_tasks = {id_task}; std::unordered_set<int> id_tasks = {id_task};
@ -2322,10 +2349,21 @@ struct server_context {
void receive_multi_results( void receive_multi_results(
const std::unordered_set<int> & id_tasks, const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler, const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
const std::function<void(json)> & error_handler) { const std::function<void(json)> & error_handler,
const std::function<bool()> & is_connection_closed) {
std::vector<server_task_result_ptr> results(id_tasks.size()); std::vector<server_task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) { for (int i = 0; i < (int)id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv(id_tasks); server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
if (is_connection_closed()) {
cancel_tasks(id_tasks);
return;
}
if (result == nullptr) {
i--; // retry
continue;
}
if (result->is_error()) { if (result->is_error()) {
error_handler(result->to_json()); error_handler(result->to_json());
@ -2349,10 +2387,20 @@ struct server_context {
void receive_cmpl_results_stream( void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler, const std::function<bool(server_task_result_ptr&)> & result_handler,
const std::function<void(json)> & error_handler) { const std::function<void(json)> & error_handler,
const std::function<bool()> & is_connection_closed) {
size_t n_finished = 0; size_t n_finished = 0;
while (true) { while (true) {
server_task_result_ptr result = queue_results.recv(id_tasks); server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
if (is_connection_closed()) {
cancel_tasks(id_tasks);
return;
}
if (result == nullptr) {
continue; // retry
}
if (result->is_error()) { if (result->is_error()) {
error_handler(result->to_json()); error_handler(result->to_json());
@ -3633,6 +3681,7 @@ int main(int argc, char ** argv) {
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type, server_task_type type,
json & data, json & data,
std::function<bool()> is_connection_closed,
httplib::Response & res, httplib::Response & res,
oaicompat_type oaicompat) { oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@ -3694,7 +3743,7 @@ int main(int argc, char ** argv) {
} }
}, [&](const json & error_data) { }, [&](const json & error_data) {
res_error(res, error_data); res_error(res, error_data);
}); }, is_connection_closed);
ctx_server.queue_results.remove_waiting_task_ids(task_ids); ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else { } else {
@ -3704,6 +3753,7 @@ int main(int argc, char ** argv) {
if (res_json.is_array()) { if (res_json.is_array()) {
for (const auto & res : res_json) { for (const auto & res : res_json) {
if (!server_sent_event(sink, "data", res)) { if (!server_sent_event(sink, "data", res)) {
// sending failed (HTTP connection closed), cancel the generation
return false; return false;
} }
} }
@ -3713,6 +3763,9 @@ int main(int argc, char ** argv) {
} }
}, [&](const json & error_data) { }, [&](const json & error_data) {
server_sent_event(sink, "error", error_data); server_sent_event(sink, "error", error_data);
}, [&sink]() {
// note: do not use req.is_connection_closed here because req is already destroyed
return !sink.is_writable();
}); });
if (oaicompat != OAICOMPAT_TYPE_NONE) { if (oaicompat != OAICOMPAT_TYPE_NONE) {
static const std::string ev_done = "data: [DONE]\n\n"; static const std::string ev_done = "data: [DONE]\n\n";
@ -3735,6 +3788,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
data, data,
req.is_connection_closed,
res, res,
OAICOMPAT_TYPE_NONE); OAICOMPAT_TYPE_NONE);
}; };
@ -3744,6 +3798,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
data, data,
req.is_connection_closed,
res, res,
OAICOMPAT_TYPE_COMPLETION); OAICOMPAT_TYPE_COMPLETION);
}; };
@ -3820,6 +3875,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_INFILL, SERVER_TASK_TYPE_INFILL,
data, data,
req.is_connection_closed,
res, res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
}; };
@ -3834,6 +3890,7 @@ int main(int argc, char ** argv) {
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
data, data,
req.is_connection_closed,
res, res,
OAICOMPAT_TYPE_CHAT); OAICOMPAT_TYPE_CHAT);
}; };
@ -3980,7 +4037,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) { }, [&](const json & error_data) {
res_error(res, error_data); res_error(res, error_data);
error = true; error = true;
}); }, req.is_connection_closed);
ctx_server.queue_results.remove_waiting_task_ids(task_ids); ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} }
@ -4070,7 +4127,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) { }, [&](const json & error_data) {
res_error(res, error_data); res_error(res, error_data);
error = true; error = true;
}); }, req.is_connection_closed);
} }
if (error) { if (error) {

View file

@ -1,4 +1,5 @@
import pytest import pytest
import requests
import time import time
from openai import OpenAI from openai import OpenAI
from utils import * from utils import *
@ -405,3 +406,23 @@ def test_n_probs_post_sampling():
assert "bytes" in prob and type(prob["bytes"]) == list assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
def test_cancel_request():
global server
server.n_ctx = 4096
server.n_predict = -1
server.n_slots = 1
server.server_slots = True
server.start()
# send a request that will take a long time, but cancel it before it finishes
try:
server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
}, timeout=0.1)
except requests.exceptions.ReadTimeout:
pass # expected
# make sure the slot is free
time.sleep(1) # wait for HTTP_POLLING_SECONDS
res = server.make_request("GET", "/slots")
assert res.body[0]["is_processing"] == False

View file

@ -26,6 +26,9 @@ from re import RegexFlag
import wget import wget
DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30
class ServerResponse: class ServerResponse:
headers: dict headers: dict
status_code: int status_code: int
@ -88,7 +91,7 @@ class ServerProcess:
if "PORT" in os.environ: if "PORT" in os.environ:
self.server_port = int(os.environ["PORT"]) self.server_port = int(os.environ["PORT"])
def start(self, timeout_seconds: int = 10) -> None: def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
if "LLAMA_SERVER_BIN_PATH" in os.environ: if "LLAMA_SERVER_BIN_PATH" in os.environ:
server_path = os.environ["LLAMA_SERVER_BIN_PATH"] server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
elif os.name == "nt": elif os.name == "nt":
@ -219,17 +222,18 @@ class ServerProcess:
path: str, path: str,
data: dict | Any | None = None, data: dict | Any | None = None,
headers: dict | None = None, headers: dict | None = None,
timeout: float | None = None,
) -> ServerResponse: ) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}" url = f"http://{self.server_host}:{self.server_port}{path}"
parse_body = False parse_body = False
if method == "GET": if method == "GET":
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers, timeout=timeout)
parse_body = True parse_body = True
elif method == "POST": elif method == "POST":
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data, timeout=timeout)
parse_body = True parse_body = True
elif method == "OPTIONS": elif method == "OPTIONS":
response = requests.options(url, headers=headers) response = requests.options(url, headers=headers, timeout=timeout)
else: else:
raise ValueError(f"Unimplemented method: {method}") raise ValueError(f"Unimplemented method: {method}")
result = ServerResponse() result = ServerResponse()

View file

@ -95,11 +95,11 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
auto generate = [&](const std::string & prompt) { auto generate = [&](const std::string & prompt, bool is_first) {
std::string response; std::string response;
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens); std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); GGML_ABORT("failed to tokenize the prompt\n");
@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
// generate a response // generate a response
printf("\033[33m"); printf("\033[33m");
std::string response = generate(prompt); std::string response = generate(prompt, prev_len == 0);
printf("\n\033[0m"); printf("\n\033[0m");
// add the response to the messages // add the response to the messages

View file

@ -78,3 +78,40 @@ play the audio:
$ aplay output.wav $ aplay output.wav
``` ```
### Running the example with llama-server
Running this example with `llama-server` is also possible and requires two
server instances to be started. One will serve the LLM model and the other
will serve the voice decoder model.
The LLM model server can be started with the following command:
```console
$ ./build/bin/llama-server -m ./models/outetts-0.2-0.5B-q8_0.gguf --port 8020
```
And the voice decoder model server can be started using:
```console
./build/bin/llama-server -m ./models/wavtokenizer-large-75-f16.gguf --port 8021 --embeddings --pooling none
```
Then we can run [tts-outetts.py](tts-outetts.py) to generate the audio.
First create a virtual environment for python and install the required
dependencies (this in only required to be done once):
```console
$ python3 -m venv venv
$ source venv/bin/activate
(venv) pip install requests numpy
```
And then run the python script using:
```conole
(venv) python ./examples/tts/tts-outetts.py http://localhost:8020 http://localhost:8021 "Hello world"
spectrogram generated: n_codes: 90, n_embd: 1282
converting to audio ...
audio generated: 28800 samples
audio written to file "output.wav"
```
And to play the audio we can again use aplay or any other media player:
```console
$ aplay output.wav
```

View file

@ -3,6 +3,121 @@ import sys
#import struct #import struct
import requests import requests
import re import re
import struct
import numpy as np
from concurrent.futures import ThreadPoolExecutor
def fill_hann_window(size, periodic=True):
if periodic:
return np.hanning(size + 1)[:-1]
return np.hanning(size)
def irfft(n_fft, complex_input):
return np.fft.irfft(complex_input, n=n_fft)
def fold(buffer, n_out, n_win, n_hop, n_pad):
result = np.zeros(n_out)
n_frames = len(buffer) // n_win
for i in range(n_frames):
start = i * n_hop
end = start + n_win
result[start:end] += buffer[i * n_win:(i + 1) * n_win]
return result[n_pad:-n_pad] if n_pad > 0 else result
def process_frame(args):
l, n_fft, ST, hann = args
frame = irfft(n_fft, ST[l])
frame = frame * hann
hann2 = hann * hann
return frame, hann2
def embd_to_audio(embd, n_codes, n_embd, n_thread=4):
embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd)
n_fft = 1280
n_hop = 320
n_win = 1280
n_pad = (n_win - n_hop) // 2
n_out = (n_codes - 1) * n_hop + n_win
hann = fill_hann_window(n_fft, True)
E = np.zeros((n_embd, n_codes), dtype=np.float32)
for l in range(n_codes):
for k in range(n_embd):
E[k, l] = embd[l, k]
half_embd = n_embd // 2
S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64)
for k in range(half_embd):
for l in range(n_codes):
mag = E[k, l]
phi = E[k + half_embd, l]
mag = np.clip(np.exp(mag), 0, 1e2)
S[l, k] = mag * np.exp(1j * phi)
res = np.zeros(n_codes * n_fft)
hann2_buffer = np.zeros(n_codes * n_fft)
with ThreadPoolExecutor(max_workers=n_thread) as executor:
args = [(l, n_fft, S, hann) for l in range(n_codes)]
results = list(executor.map(process_frame, args))
for l, (frame, hann2) in enumerate(results):
res[l*n_fft:(l+1)*n_fft] = frame
hann2_buffer[l*n_fft:(l+1)*n_fft] = hann2
audio = fold(res, n_out, n_win, n_hop, n_pad)
env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad)
mask = env > 1e-10
audio[mask] /= env[mask]
return audio
def save_wav(filename, audio_data, sample_rate):
num_channels = 1
bits_per_sample = 16
bytes_per_sample = bits_per_sample // 8
data_size = len(audio_data) * bytes_per_sample
byte_rate = sample_rate * num_channels * bytes_per_sample
block_align = num_channels * bytes_per_sample
chunk_size = 36 + data_size # 36 = size of header minus first 8 bytes
header = struct.pack(
'<4sI4s4sIHHIIHH4sI',
b'RIFF',
chunk_size,
b'WAVE',
b'fmt ',
16, # fmt chunk size
1, # audio format (PCM)
num_channels,
sample_rate,
byte_rate,
block_align,
bits_per_sample,
b'data',
data_size
)
audio_data = np.clip(audio_data * 32767, -32768, 32767)
pcm_data = audio_data.astype(np.int16)
with open(filename, 'wb') as f:
f.write(header)
f.write(pcm_data.tobytes())
def process_text(text: str): def process_text(text: str):
text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
@ -170,6 +285,15 @@ n_embd = len(embd[0])
print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd)) print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))
# post-process the spectrogram to convert to audio # post-process the spectrogram to convert to audio
# TODO: see the tts.cpp:embd_to_audio() and implement it in Python
print('converting to audio ...') print('converting to audio ...')
print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python') audio = embd_to_audio(embd, n_codes, n_embd)
print('audio generated: %d samples' % len(audio))
filename = "output.wav"
sample_rate = 24000 # sampling rate
# zero out first 0.25 seconds
audio[:24000 // 4] = 0.0
save_wav(filename, audio, sample_rate)
print('audio written to file "%s"' % filename)

View file

@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true); prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
} }
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
std::vector<llama_token> result;
size_t start = 0;
size_t end = str.find(delimiter);
//first token is always a newline, as it was not previously added
result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
while (end != std::string::npos) {
std::string current_word = str.substr(start, end - start);
auto tmp = common_tokenize(vocab, current_word, false, true);
result.push_back(tmp[0]);
start = end + delimiter.length();
end = str.find(delimiter, start);
}
// Add the last part
std::string current_word = str.substr(start);
auto tmp = common_tokenize(vocab, current_word, false, true);
if (tmp.size() > 0) {
result.push_back(tmp[0]);
}
return result;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
common_params params; common_params params;
@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us(); const auto t_main_start = ggml_time_us();
std::vector<llama_token> codes; std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;
// process prompt and generate voice codes // process prompt and generate voice codes
{ {
@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
// convert the input text into the necessary format expected by OuteTTS // convert the input text into the necessary format expected by OuteTTS
{ {
std::string prompt_clean = process_text(params.prompt); std::string prompt_clean = process_text(params.prompt);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
}
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
int n_past = batch.n_tokens; int n_past = batch.n_tokens;
int n_decode = 0; int n_decode = 0;
bool next_token_uses_guide_token = true;
while (n_decode <= n_predict) { while (n_decode <= n_predict) {
// prepare the next batch // prepare the next batch
common_batch_clear(batch); common_batch_clear(batch);
@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
continue; continue;
} }
const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]); llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
}
//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);
common_sampler_accept(smpl[i], new_token_id, true); common_sampler_accept(smpl[i], new_token_id, true);

View file

@ -185,6 +185,9 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
# extra artifacts # extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})

View file

@ -203,6 +203,8 @@ extern "C" {
// Backend registry // Backend registry
// //
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Backend (reg) enumeration // Backend (reg) enumeration
GGML_API size_t ggml_backend_reg_count(void); GGML_API size_t ggml_backend_reg_count(void);
GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);

View file

@ -1384,16 +1384,20 @@ extern "C" {
float scale, float scale,
float max_bias); float max_bias);
GGML_API struct ggml_tensor * ggml_soft_max_back( GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b,
float scale,
float max_bias);
// in-place, returns view(a) // in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b,
float scale,
float max_bias);
// rotary position embedding // rotary position embedding
// if (mode & 1) - skip n_past elements (NOT SUPPORTED) // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
@ -1500,7 +1504,7 @@ extern "C" {
// rotary position embedding backward, i.e compute dx from dy // rotary position embedding backward, i.e compute dx from dy
// a - dy // a - dy
GGML_API struct ggml_tensor * ggml_rope_back( GGML_API struct ggml_tensor * ggml_rope_ext_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, // gradients of ggml_rope result struct ggml_tensor * a, // gradients of ggml_rope result
struct ggml_tensor * b, // positions struct ggml_tensor * b, // positions
@ -1515,6 +1519,23 @@ extern "C" {
float beta_fast, float beta_fast,
float beta_slow); float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_multi_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// clamp // clamp
// in-place, returns view(a) // in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_clamp( GGML_API struct ggml_tensor * ggml_clamp(

View file

@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
return true; return true;
} }
// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) { static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) { switch (op) {
case GGML_OP_SCALE: case GGML_OP_SCALE:
@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_LOG: case GGML_OP_LOG:
case GGML_OP_UNARY: case GGML_OP_UNARY:
case GGML_OP_ROPE: case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
return true; return true;
default: default:

View file

@ -208,7 +208,6 @@ extern "C" {
// Internal backend registry API // Internal backend registry API
GGML_API void ggml_backend_register(ggml_backend_reg_t reg); GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Add backend dynamic loading support to the backend // Add backend dynamic loading support to the backend

View file

@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
uint32_t utmp[4]; uint32_t utmp[4];
#ifdef __ARM_NEON #ifdef __ARM_FEATURE_SVE
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
uint32x2_t mins8 = { 0 };
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
svint32_t sumi1_1 = svdup_n_s32(0);
svint32_t sumi1_2 = svdup_n_s32(0);
svint32_t sumi2 = svdup_n_s32(0);
svint32_t sumi2_1 = svdup_n_s32(0);
svint32_t sumi2_2 = svdup_n_s32(0);
switch (vector_length) {
case 128:
{
for (int j = 0; j < QK_K/64; ++j) {
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4 += 32;
}
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
} break;
case 256:
case 512:
{
for (int j = 0; j < QK_K/64; ++j) {
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
}
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
} break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sumf;
#elif __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0); const int32x4_t mzero = vdupq_n_s32(0);

View file

@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes(
} }
} }
static void ggml_compute_forward_dup_q(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
size_t qk = ggml_blck_size(type);
const int64_t nr = ggml_nelements(src1) / qk;
// destination must be contiguous in the first dimension
GGML_ASSERT(nb10 == ggml_type_size(dst->type));
// must either have first dimension large enough to hold a row, or fully contiguous
GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
const int ith = params->ith;
const int nth = params->nth;
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) {
uint32_t i = ir * qk;
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
dequantize_row_q(
(const void *) ((char *) src0->data + x_offset),
(float *) ((char *) dst->data + dst_offset), qk);
}
}
static void ggml_compute_forward_dup( static void ggml_compute_forward_dup(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup(
} break; } break;
default: default:
{ {
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
ggml_compute_forward_dup_q(params, dst);
break;
}
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} }
@ -6691,20 +6746,20 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * grad = dst->src[0];
const struct ggml_tensor * grad = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1];
assert(ggml_is_contiguous_1(grad)); assert(ggml_is_contiguous_1(grad));
assert(ggml_is_contiguous_1(src0)); assert(ggml_is_contiguous_1(src1));
assert(ggml_is_contiguous_1(dst)); assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst)); assert(ggml_are_same_shape(src1, dst));
assert(ggml_are_same_shape(src0, grad)); assert(ggml_are_same_shape(src1, grad));
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int nc = src0->ne[0]; const int nc = src1->ne[0];
const int nr = ggml_nrows(src0); const int nr = ggml_nrows(src1);
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
@ -6716,7 +6771,7 @@ static void ggml_compute_forward_silu_back_f32(
for (int i1 = ir0; i1 < ir1; i1++) { for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_silu_backward_f32(nc, ggml_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])), (float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])), (float *) ((char *) src1->data + i1*(src1->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1]))); (float *) ((char *) grad->data + i1*(grad->nb[1])));
#ifndef NDEBUG #ifndef NDEBUG
@ -6895,7 +6950,7 @@ static void ggml_compute_forward_norm_f32(
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f); GGML_ASSERT(eps >= 0.0f);
// TODO: optimize // TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -6966,7 +7021,7 @@ static void ggml_compute_forward_rms_norm_f32(
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f); GGML_ASSERT(eps >= 0.0f);
// TODO: optimize // TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -7018,12 +7073,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
@ -7042,8 +7098,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
const int64_t i12 = i02; const int64_t i12 = i02;
const int64_t i13 = i03; const int64_t i13 = i03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_float sum_xx = 0.0; ggml_float sum_xx = 0.0;
ggml_float sum_xdz = 0.0; ggml_float sum_xdz = 0.0;
@ -7066,9 +7122,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
{ {
// z = rms_norm(x) // z = rms_norm(x)
// //
// rms_norm(src0) = // rms_norm(src1) =
// scale( // scale(
// src0, // src1,
// div( // div(
// 1, // 1,
// sqrt( // sqrt(
@ -7076,13 +7132,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
// scale( // scale(
// sum( // sum(
// sqr( // sqr(
// src0)), // src1)),
// (1.0/N)), // (1.0/N)),
// eps)))); // eps))));
// postorder: // postorder:
// ## op args grad // ## op args grad
// 00 param src0 grad[#00] // 00 param src1 grad[#00]
// 01 const 1 // 01 const 1
// 02 sqr (#00) grad[#02] // 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03] // 03 sum (#02) grad[#03]
@ -7159,6 +7215,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms) // dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x); ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@ -7750,12 +7807,13 @@ static void ggml_compute_forward_out_prod_f32(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10); GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02); GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne2 % ne02 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float));
@ -7797,6 +7855,10 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16; const int64_t blck_1 = 16;
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t bir = ir0; bir < ir1; bir += blck_1) { for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1); const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@ -7807,8 +7869,8 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2; const int64_t i02 = i2 / dps2;
const int64_t i03 = i3; const int64_t i03 = i3 / dps3;
//const int64_t i10 = i1; //const int64_t i10 = i1;
const int64_t i12 = i2; const int64_t i12 = i2;
@ -8906,9 +8968,9 @@ static void ggml_compute_forward_soft_max(
} }
// ggml_compute_forward_soft_max_back // ggml_compute_forward_soft_max_ext_back
static void ggml_compute_forward_soft_max_back_f32( static void ggml_compute_forward_soft_max_ext_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
@ -8921,6 +8983,14 @@ static void ggml_compute_forward_soft_max_back_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
// TODO: handle transposed/permuted matrices // TODO: handle transposed/permuted matrices
const int ith = params->ith; const int ith = params->ith;
@ -8969,10 +9039,11 @@ static void ggml_compute_forward_soft_max_back_f32(
// linear runtime, no additional memory // linear runtime, no additional memory
float dot_y_dy = 0; float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy); ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy); ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y); ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_scale_f32(nc, dx, scale);
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
@ -8983,7 +9054,7 @@ static void ggml_compute_forward_soft_max_back_f32(
} }
} }
static void ggml_compute_forward_soft_max_back( static void ggml_compute_forward_soft_max_ext_back(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
@ -8992,7 +9063,7 @@ static void ggml_compute_forward_soft_max_back(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_soft_max_back_f32(params, dst); ggml_compute_forward_soft_max_ext_back_f32(params, dst);
} break; } break;
default: default:
{ {
@ -9985,9 +10056,10 @@ static void ggml_compute_forward_im2col_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -10009,11 +10081,11 @@ static void ggml_compute_forward_im2col_back_f32(
const int64_t IH = is_2D ? ne1 : 1; const int64_t IH = is_2D ? ne1 : 1;
const int64_t IW = ne0; const int64_t IW = ne0;
const int64_t KH = is_2D ? ne01 : 1; const int64_t KH = is_2D ? ne11 : 1;
const int64_t KW = ne00; const int64_t KW = ne10;
const int64_t OH = is_2D ? ne12 : 1; const int64_t OH = is_2D ? ne02 : 1;
const int64_t OW = ne11; const int64_t OW = ne01;
int ofs0 = is_2D ? nb3 : nb2; int ofs0 = is_2D ? nb3 : nb2;
int ofs1 = is_2D ? nb2 : nb1; int ofs1 = is_2D ? nb2 : nb1;
@ -10059,9 +10131,9 @@ static void ggml_compute_forward_im2col_back_f32(
continue; continue;
} }
const float * const src_data = (const float *) src1->data const float * const grad_in = (const float *) src0->data
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
} }
} }
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@ -12484,22 +12556,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
const struct ggml_tensor * opt0 = dst->src[2]; const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
const int64_t ith = params->ith; const int64_t ith = params->ith;
const int64_t nth = params->nth; const int64_t nth = params->nth;
// TODO: handle transposed/permuted matrices // TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0]; const int64_t nc = src0f->ne[0];
const int64_t nr = ggml_nrows(src0); const int64_t nr = ggml_nrows(src0f);
// rows per thread // rows per thread
const int64_t dr = (nr + nth - 1)/nth; const int64_t dr = (nr + nth - 1)/nth;
@ -12508,12 +12580,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const int64_t ir0 = dr*ith; const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr); const int64_t ir1 = MIN(ir0 + dr, nr);
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
for (int64_t i1 = ir0; i1 < ir1; i1++) { for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
#ifndef NDEBUG #ifndef NDEBUG
for (int64_t i = 0; i < nc; ++i) { for (int64_t i = 0; i < nc; ++i) {
@ -12526,11 +12598,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
// soft_max // soft_max
float max = -INFINITY; float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0); ggml_vec_max_f32(nc, &max, s0);
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
assert(sum > 0.0); assert(sum > 0.0);
ggml_vec_scale_f32(nc, ds0, 1.0/sum); ggml_vec_scale_f32(nc, ds0, 1.0/sum);
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_scale_f32(nc, ds0, d_by_nr); ggml_vec_scale_f32(nc, ds0, d_by_nr);
@ -12827,7 +12899,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
{ {
ggml_compute_forward_soft_max_back(params, tensor); ggml_compute_forward_soft_max_ext_back(params, tensor);
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
@ -13668,6 +13740,7 @@ struct ggml_cplan ggml_graph_plan(
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE: case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{ {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break; } break;

View file

@ -403,8 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
case GGML_OP_ROPE_BACK: case GGML_OP_SOFT_MAX_BACK: {
return op->src[2] == NULL && (op->op_params[2] & 4) == 0; if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
return false;
}
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_IM2COL_BACK: case GGML_OP_IM2COL_BACK:
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:

View file

@ -5,95 +5,89 @@
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { template <bool use_shared>
const int warp_id = threadIdx.x / WARP_SIZE; static __global__ void cross_entropy_loss_f32(
const int lane_id = threadIdx.x % WARP_SIZE; const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; extern __shared__ float tmp[];
const int ne_tmp = WARP_SIZE*nclasses; logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
extern __shared__ float tmp_all[];
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
// Each warp first loads ne_tmp logits/labels into shared memory:
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
const int ig = i0*nclasses + i; // ig == i global
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
}
// Each thread in the warp then calculates the cross entropy loss for a single row.
// TODO: pad in order to avoid shared memory bank conflicts.
// Find maximum for softmax: // Find maximum for softmax:
float max = -INFINITY; float max_logit = -INFINITY;
for (int i = 0; i < nclasses; ++i) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); const float val = logits[i];
max_logit = fmaxf(max_logit, val);
if (use_shared) {
tmp[i] = val;
}
} }
max_logit = warp_reduce_max(max_logit);
// Calculate log(softmax(logits)) which is just logits - max: // Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f; float sum = 0.0f;
for (int i = 0; i < nclasses; ++i) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
float val = tmp_logits[lane_id*nclasses + i] - max; const float logit_i = use_shared ? tmp[i] : logits[i];
sum += expf(val); sum += expf(logit_i - max_logit);
tmp_logits[lane_id*nclasses + i] = val;
} }
sum = warp_reduce_sum(sum);
sum = logf(sum); sum = logf(sum);
// log(exp(logits - max) / sum) = (logits - max) - log(sum) // log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f; float loss = 0.0f;
for (int i = 0; i < nclasses; ++i) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; const float logit_i = use_shared ? tmp[i] : logits[i];
loss += (logit_i - max_logit - sum) * labels[i];
} }
loss = -warp_reduce_sum(loss) / (float)k; loss = -warp_reduce_sum(loss) / (float)k;
__syncthreads(); if (threadIdx.x != 0) {
if (lane_id == 0) {
tmp_all[warp_id] = loss;
}
__syncthreads();
if (warp_id != 0) {
return;
}
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
loss = warp_reduce_sum(loss);
if (lane_id != 0) {
return; return;
} }
dst[blockIdx.x] = loss; dst[blockIdx.x] = loss;
} }
static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) { template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
float * __restrict__ dst, const int nclasses) {
extern __shared__ float tmp[]; extern __shared__ float tmp[];
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
dst += int64_t(blockIdx.x)*nclasses;
float maxval = -INFINITY; float maxval = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[blockIdx.x*nclasses + i]; const float val = logits[i];
maxval = fmaxf(maxval, val); maxval = fmaxf(maxval, val);
tmp[i] = val;
if (use_shared) {
tmp[i] = val;
}
} }
maxval = warp_reduce_max(maxval); maxval = warp_reduce_max(maxval);
float sum = 0.0f; float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = expf(tmp[i] - maxval); const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
sum += val; sum += val;
tmp[i] = val;
if (use_shared) {
tmp[i] = val;
} else {
dst[i] = val;
}
} }
sum = warp_reduce_sum(sum); sum = warp_reduce_sum(sum);
const float sm_scale = 1.0f/sum; const float sm_scale = 1.0f/sum;
const float d_by_nrows = *loss/gridDim.x; const float d_by_nrows = *grad/gridDim.x;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows; const float val = use_shared ? tmp[i] : dst[i];
dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
} }
} }
@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool & pool = ctx.pool(); ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); const dim3 blocks_num(nrows, 1, 1);
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); const size_t nbytes_shared = ne00*sizeof(float);
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x); ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
}
CUDA_CHECK(cudaGetLastError());
// Combine results from individual blocks: // Combine results from individual blocks:
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
} }
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * grad = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src0f = dst->src[1];
const ggml_tensor * opt0 = dst->src[2]; const ggml_tensor * src1f = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1f->type == GGML_TYPE_F32);
GGML_ASSERT(opt0->type == GGML_TYPE_F32); GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_scalar(grad));
GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1)); GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0f, dst));
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0f);
const float * src0_d = (const float *) src0->data; const float * grad_d = (const float *) grad->data;
const float * src1_d = (const float *) src1->data; const float * src0f_d = (const float *) src0f->data;
const float * opt0_d = (const float *) opt0->data; const float * src1f_d = (const float *) src1f->data;
float * dst_d = (float *) dst->data; float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(WARP_SIZE, 1, 1); const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1); const dim3 blocks_num(nrows, 1, 1);
const int shmem = ne00*sizeof(float); const size_t nbytes_shared = ne00*sizeof(float);
cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00); const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
}
} }

View file

@ -3,15 +3,15 @@
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows( static __global__ void k_get_rows(
const void * src0, const int32_t * src1, dst_t * dst, const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
@ -22,10 +22,10 @@ static __global__ void k_get_rows(
const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
@ -39,15 +39,15 @@ static __global__ void k_get_rows(
template<typename src0_t, typename dst_t> template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float( static __global__ void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst, const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
const int i00 = blockIdx.x*blockDim.x + threadIdx.x; const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = src0_row[i00]; dst_row[i00] = src0_row[i00];
} }
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
template<int qk, int qr, dequantize_kernel_t dq> template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, static void get_rows_cuda(
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
GGML_ASSERT(ne00 % 2 == 0); GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>( k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/ ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/); s10, s11, s12/*, s13*/);
GGML_UNUSED(dst); GGML_UNUSED(dst);
} }
template<typename src0_t> template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, static void get_rows_cuda_float(
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ne13 == 1);
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12); const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
//const size_t s13 = nb13 / ggml_element_size(src1); //const size_t s13 = nb13 / ggml_element_size(src1);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>( k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/ ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/); s10, s11, s12/*, s13*/);
GGML_UNUSED(dst); GGML_UNUSED(dst);
} }
@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data; const void * src0_d = (const void *) src0->data;
float * dst_d = (float *)dst->data; const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break; break;
default: default:
// TODO: k-quants // TODO: k-quants
@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
break; break;
} }
} }
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
GGML_TENSOR_BINARY_OP_LOCALS
const float * src0_d = (const float *) src0->data;
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne02*ne03 == 1);
GGML_ASSERT(ne12*ne13 == 1);
GGML_ASSERT(ne2*ne3 == 1);
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
}

View file

@ -1,5 +1,8 @@
#include "common.cuh" #include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256 #define CUDA_GET_ROWS_BLOCK_SIZE 256
#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
ggml_cuda_op_get_rows(ctx, dst); ggml_cuda_op_get_rows(ctx, dst);
break; break;
case GGML_OP_GET_ROWS_BACK:
ggml_cuda_op_get_rows_back(ctx, dst);
break;
case GGML_OP_DUP: case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst); ggml_cuda_dup(ctx, dst);
break; break;
@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
ggml_cuda_op_leaky_relu(ctx, dst); ggml_cuda_op_leaky_relu(ctx, dst);
break; break;
case GGML_OP_SILU_BACK:
ggml_cuda_op_silu_back(ctx, dst);
break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
ggml_cuda_op_rms_norm(ctx, dst); ggml_cuda_op_rms_norm(ctx, dst);
break; break;
case GGML_OP_RMS_NORM_BACK:
ggml_cuda_op_rms_norm_back(ctx, dst);
break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@ -2138,9 +2147,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
ggml_cuda_op_soft_max(ctx, dst); ggml_cuda_op_soft_max(ctx, dst);
break; break;
case GGML_OP_SOFT_MAX_BACK:
ggml_cuda_op_soft_max_back(ctx, dst);
break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst); ggml_cuda_op_rope(ctx, dst);
break; break;
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst); ggml_cuda_op_im2col(ctx, dst);
break; break;
@ -2909,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
} }
} break; } break;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
switch (op->src[0]->type) { switch (op->src[0]->type) {
@ -2925,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false; return false;
} }
} break; } break;
case GGML_OP_GET_ROWS_BACK:
{
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
} break;
case GGML_OP_CPY: case GGML_OP_CPY:
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;
@ -2998,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
} }
return false; return false;
} break; } break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]);
break;
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break; break;
case GGML_OP_NONE: case GGML_OP_NONE:
@ -3024,8 +3047,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
return true; return true;
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_ROPE: case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]); case GGML_OP_ROPE_BACK: {
const size_t ts = ggml_type_size(op->src[0]->type);
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
}
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_SUM: case GGML_OP_SUM:
@ -3081,6 +3113,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
return op->ne[1]; return op->ne[1];
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE: case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
return op->ne[2]; return op->ne[2];
default: default:
return ggml_nrows(op); return ggml_nrows(op);

View file

@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
float2 mean_var = make_float2(0.f, 0.f); x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float2 mean_var = make_float2(0.0f, 0.0f);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col]; const float xi = x[col];
mean_var.x += xi; mean_var.x += xi;
mean_var.y += xi * xi; mean_var.y += xi * xi;
} }
// sum up partial sums // sum up partial sums
mean_var = warp_reduce_sum(mean_var); mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) { if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float2 s_sum[32]; __shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) { if (lane_id == 0) {
s_sum[warp_id] = mean_var; s_sum[warp_id] = mean_var;
} }
@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
const float inv_std = rsqrtf(var + eps); const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; dst[col] = (x[col] - mean) * inv_std;
} }
} }
@ -40,14 +44,8 @@ template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx // blockIdx.x: num_groups idx
// threadIdx.x: block_size idx // threadIdx.x: block_size idx
int start = blockIdx.x * group_size; const int start = blockIdx.x*group_size + threadIdx.x;
int end = start + group_size; const int end = min(blockIdx.x*group_size + group_size, ne_elements);
start += threadIdx.x;
if (end >= ne_elements) {
end = ne_elements;
}
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
} }
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) { if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32]; __shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) { if (lane_id == 0) {
s_sum[warp_id] = tmp; s_sum[warp_id] = tmp;
} }
@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
} }
float mean = tmp / group_size; const float mean = tmp / group_size;
tmp = 0.0f; tmp = 0.0f;
for (int j = start; j < end; j += block_size) { for (int j = start; j < end; j += block_size) {
float xi = x[j] - mean; const float xi = x[j] - mean;
dst[j] = xi; dst[j] = xi;
tmp += xi * xi; tmp += xi * xi;
} }
@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
__shared__ float s_sum[32]; __shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) { if (lane_id == 0) {
s_sum[warp_id] = tmp; s_sum[warp_id] = tmp;
} }
@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
} }
float variance = tmp / group_size; const float variance = tmp / group_size;
float scale = rsqrtf(variance + eps); const float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) { for (int j = start; j < end; j += block_size) {
dst[j] *= scale; dst[j] *= scale;
} }
@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col]; const float xi = x[col];
tmp += xi * xi; tmp += xi * xi;
} }
// sum up partial sums // sum up partial sums
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) { if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32]; __shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) { if (lane_id == 0) {
s_sum[warp_id] = tmp; s_sum[warp_id] = tmp;
} }
@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
const float scale = rsqrtf(mean + eps); const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col]; dst[col] = scale * x[col];
}
}
template <int block_size>
static __global__ void rms_norm_back_f32(
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
grad += int64_t(row)*ncols;
xf += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
for (int col = tid; col < ncols; col += block_size) {
const float xfi = xf[col];
sum_xx += xfi * xfi;
sum_xg += xfi * grad[col];
}
// sum up partial sums
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = warp_reduce_sum(sum_xg);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum_xx[32];
__shared__ float s_sum_xg[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum_xx[warp_id] = sum_xx;
s_sum_xg[warp_id] = sum_xg;
}
__syncthreads();
sum_xx = s_sum_xx[lane_id];
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = s_sum_xg[lane_id];
sum_xg = warp_reduce_sum(sum_xg);
}
const float mean_eps = sum_xx / ncols + eps;
const float sum_eps = sum_xx + ncols*eps;
const float scale_grad = rsqrtf(mean_eps);
const float scale_x = -scale_grad * sum_xg/sum_eps;
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale_grad*grad[col] + scale_x*xf[col];
} }
} }
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps); norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
} }
} }
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { static void group_norm_f32_cuda(
const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
if (group_size < 1024) { if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps); group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
} }
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps); rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
} }
} }
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;
@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
} }
@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
float eps; float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float)); memcpy(&eps, dst->op_params + 1, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
} }
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
const float * grad_d = (const float *) grad->data;
const float * src0f_d = (const float *) src0f->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0f);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}

View file

@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne01 == ne11);
GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10); GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == src0->ne[2]); GGML_ASSERT(ne2 % src0->ne[2] == 0);
GGML_ASSERT(ne3 % src0->ne[3] == 0);
GGML_ASSERT(ne2 == src1->ne[2]); GGML_ASSERT(ne2 == src1->ne[2]);
GGML_ASSERT(ne3 == src0->ne[3]);
GGML_ASSERT(ne3 == src1->ne[3]); GGML_ASSERT(ne3 == src1->ne[3]);
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
GGML_ASSERT(ne2 == 1);
GGML_ASSERT(ne3 == 1);
CUBLAS_CHECK(cublasSetStream(handle, stream)); CUBLAS_CHECK(cublasSetStream(handle, stream));
const bool src1_T = ggml_is_transposed(src1); const bool src1_T = ggml_is_transposed(src1);
@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
CUBLAS_CHECK( // data strides in dimensions 2/3
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, const size_t s02 = nb02 / sizeof(float);
ne0, ne1, ne01, const size_t s03 = nb03 / sizeof(float);
&alpha, src0_d, ne00, const size_t s12 = nb12 / sizeof(float);
src1_d, ldb, const size_t s13 = nb13 / sizeof(float);
&beta, dst_d, ne0)); const size_t s2 = nb2 / sizeof(float);
const size_t s3 = nb3 / sizeof(float);
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
// TODO batched matrix multiplication
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
}
}
} }

View file

@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
template<bool forward>
static __device__ void rope_yarn( static __device__ void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
float * cos_theta, float * sin_theta) { float mscale, float & cos_theta, float & sin_theta) {
// Get n-d rotational scaling corrected for extrapolation // Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap; float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp; float theta = theta_interp;
@ -29,24 +30,28 @@ static __device__ void rope_yarn(
// Get n-d magnitude scaling corrected for interpolation // Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
} }
*cos_theta = cosf(theta) * mscale; cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale; sin_theta = sinf(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
} }
template<typename T, bool has_ff> template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm( static __global__ void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) { if (i0 >= n_dims) {
const int i = row*ne0 + i0; const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0]; dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1]; dst[i + 1] = x[i + 1];
@ -54,39 +59,43 @@ static __global__ void rope_norm(
return; return;
} }
const int i = row*ne0 + i0; const int row_x = row_dst % ne1;
const int i2 = row/p_delta_rows; const int channel_x = row_dst / ne1;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0]; const float x0 = x[ix + 0];
const float x1 = x[i + 1]; const float x1 = x[ix + 1];
dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta; dst[idst + 1] = x0*sin_theta + x1*cos_theta;
} }
template<typename T, bool has_ff> template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox( static __global__ void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) { if (i0 >= n_dims) {
const int i = row*ne0 + i0; const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0]; dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1]; dst[i + 1] = x[i + 1];
@ -94,39 +103,43 @@ static __global__ void rope_neox(
return; return;
} }
const int i = row*ne0 + i0/2; const int row_x = row_dst % ne1;
const int i2 = row/p_delta_rows; const int channel_x = row_dst / ne1;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0]; const float x0 = x[ix + 0];
const float x1 = x[i + n_dims/2]; const float x1 = x[ix + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
} }
template<typename T, bool has_ff> template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi( static __global__ void rope_multi(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) { if (i0 >= n_dims) {
const int i = row*ne0 + i0; const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0]; dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1]; dst[i + 1] = x[i + 1];
@ -134,25 +147,28 @@ static __global__ void rope_multi(
return; return;
} }
const int i = row*ne0 + i0/2; const int row_x = row_dst % ne1;
const int i2 = row/p_delta_rows; const int channel_x = row_dst / ne1;
int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int idst = row_dst*ne0 + i0/2;
int sec_w = sections.v[1] + sections.v[0]; const int ix = channel_x*s2 + row_x*s1 + i0/2;
int sector = (i0 / 2) % sect_dims;
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0; float theta_base = 0.0;
if (sector < sections.v[0]) { if (sector < sections.v[0]) {
theta_base = pos[i2]*powf(theta_scale, i0/2.0f); theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
} }
else if (sector >= sections.v[0] && sector < sec_w) { else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f); theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} }
else if (sector >= sec_w && sector < sec_w + sections.v[2]) { else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f); theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} }
else if (sector >= sec_w + sections.v[2]) { else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f); theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
} }
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -160,42 +176,46 @@ static __global__ void rope_multi(
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0]; const float x0 = x[ix + 0];
const float x1 = x[i + n_dims/2]; const float x1 = x[ix + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
} }
template<typename T, bool has_ff> template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision( static __global__ void rope_vision(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ne0 + i0/2; const int row_x = row_dst % ne1;
const int i2 = row/p_delta_rows; // i2-th tokens const int channel_x = row_dst / ne1;
int sect_dims = sections.v[0] + sections.v[1]; const int idst = row_dst*ne0 + i0/2;
int sec_w = sections.v[1] + sections.v[0]; const int ix = channel_x*s2 + row_x*s1 + i0/2;
int sector = (i0 / 2) % sect_dims;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0; float theta_base = 0.0;
if (sector < sections.v[0]) { if (sector < sections.v[0]) {
const int p = sector; const int p = sector;
theta_base = pos[i2]*powf(theta_scale, p); theta_base = pos[channel_x]*powf(theta_scale, p);
} }
else if (sector >= sections.v[0] && sector < sec_w) { else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0]; const int p = sector - sections.v[0];
theta_base = pos[i2 + ne2]*powf(theta_scale, p); theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
} }
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -203,19 +223,20 @@ static __global__ void rope_vision(
float cos_theta; float cos_theta;
float sin_theta; float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0]; const float x0 = x[ix + 0];
const float x1 = x[i + n_dims]; const float x1 = x[ix + n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims] = x0*sin_theta + x1*cos_theta; dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
} }
template<typename T> template<bool forward, typename T>
static void rope_norm_cuda( static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -224,22 +245,21 @@ static void rope_norm_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors attn_factor, corr_dims, theta_scale, freq_factors);
);
} else { } else {
rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors attn_factor, corr_dims, theta_scale, freq_factors);
);
} }
} }
template<typename T> template<bool forward, typename T>
static void rope_neox_cuda( static void rope_neox_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -248,22 +268,21 @@ static void rope_neox_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>( rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors attn_factor, corr_dims, theta_scale, freq_factors);
);
} else { } else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors attn_factor, corr_dims, theta_scale, freq_factors);
);
} }
} }
template<typename T> template<bool forward, typename T>
static void rope_multi_cuda( static void rope_multi_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -272,22 +291,21 @@ static void rope_multi_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>( rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors, sections attn_factor, corr_dims, theta_scale, freq_factors, sections);
);
} else { } else {
rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>( rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors, sections attn_factor, corr_dims, theta_scale, freq_factors, sections);
);
} }
} }
template<typename T> template<bool forward, typename T>
static void rope_vision_cuda( static void rope_vision_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -298,80 +316,18 @@ static void rope_vision_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>( rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors, sections attn_factor, corr_dims, theta_scale, freq_factors, sections);
);
} else { } else {
rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>( rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
theta_scale, freq_factors, sections attn_factor, corr_dims, theta_scale, freq_factors, sections);
);
} }
} }
static void rope_norm_cuda_f16( template <bool forward>
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_norm_cuda_f32(
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_multi_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_multi_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2]; const ggml_tensor * src2 = dst->src[2];
@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data; float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type); GGML_ASSERT(src0->type == dst->type);
@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ne02 = src0->ne[2]; // num heads const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0); const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// compute // compute
if (is_neox) { if (is_neox) {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32( rope_neox_cuda<forward>(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16( rope_neox_cuda<forward>(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else if (is_mrope && !is_vision) { } else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda_f32( rope_multi_cuda<forward>(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, sections, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda_f16( rope_multi_cuda<forward>(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, sections, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else if (is_vision) { } else if (is_vision) {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda_f32( rope_vision_cuda<forward>(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, sections, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda_f16( rope_vision_cuda<forward>(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, sections, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else { } else {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32( rope_norm_cuda<forward>(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16( rope_norm_cuda<forward>(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
attn_factor, corr_dims, freq_factors, stream freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} }
} }
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<true>(ctx, dst);
}
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<false>(ctx, dst);
}

View file

@ -3,3 +3,5 @@
#define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1,5 +1,7 @@
#include "common.cuh" #include "common.cuh"
#include "ggml.h"
#include "softmax.cuh" #include "softmax.cuh"
#include <cstdint>
template <typename T> template <typename T>
static __device__ __forceinline__ float t2f32(T val) { static __device__ __forceinline__ float t2f32(T val) {
@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val); return __half2float(val);
} }
template <bool vals_smem, int ncols_template, int block_size_template, typename T> template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int rowx = blockIdx.x; const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
x += int64_t(rowx)*ncols;
mask += int64_t(rowy)*ncols * (mask != nullptr);
dst += int64_t(rowx)*ncols;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
extern __shared__ float data_soft_max_f32[]; extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations: // shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = -INFINITY; float max_val = -INFINITY;
@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
break; break;
} }
const int64_t ix = (int64_t)rowx*ncols + col; const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
const int64_t iy = (int64_t)rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val; vals[col] = val;
max_val = max(max_val, val); max_val = max(max_val, val);
@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
return; return;
} }
const int64_t idst = (int64_t)rowx*ncols + col; dst[col] = vals[col] * inv_sum;
dst[idst] = vals[col] * inv_sum; }
}
static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
grad += int64_t(rowx)*ncols;
dstf += int64_t(rowx)*ncols;
dst += int64_t(rowx)*ncols;
float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
for (int col = tid; col < ncols; col += WARP_SIZE) {
dgf_dot += dstf[col]*grad[col];
}
dgf_dot = warp_reduce_sum(dgf_dot);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
} }
} }
@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1); const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1); const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head = nrows_x/nrows_y; const uint32_t n_head = nrows_x/nrows_y;
@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer // FIXME: this limit could be raised by ~2-4x on Ampere or newer
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) { switch (ncols_x) {
case 32: case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 64: case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 128: case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 256: case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 512: case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 1024: case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 2048: case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
case 4096: case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
default: default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break; break;
} }
} else { } else {
const size_t shmem_low = WARP_SIZE*sizeof(float); const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
} }
} }
static void soft_max_back_f32_cuda(
const float * grad, const float * dstf, float * dst,
const int ncols, const int nrows, const float scale, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr; const void * src1_d = src1 ? (const void *) src1->data : nullptr;
float * dst_d = (float *) dst->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float scale = 1.0f; float scale = 1.0f;
float max_bias = 0.0f; float max_bias = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) { if (use_f16) {
const half * src1_dd = (const half *)src1_d; soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else { } else {
const float * src1_dd = (const float *)src1_d; soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} }
} }
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // grad
const ggml_tensor * src1 = dst->src[1]; // forward pass output
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
}

View file

@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024 #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i])); dst[i] = x[i] / (1.0f + expf(-x[i]));
} }
static __global__ void silu_back_f32(
const float * grad, const float * xf, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const float xfi = xf[i];
const float s = 1.0f / (1.0f + expf(-xfi));
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) { static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) { if (i >= k) {
@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k); silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} }
static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k); tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
} }
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // input from forward pass
const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;

View file

@ -4,6 +4,7 @@
#define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_STEP_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_SILU_BACK_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256 #define CUDA_SIGMOID_BLOCK_SIZE 256
@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -29,5 +29,6 @@
#include "wkv6.hpp" #include "wkv6.hpp"
#include "outprod.hpp" #include "outprod.hpp"
#include "element_wise.hpp" #include "element_wise.hpp"
#include "gla.hpp"
#endif // GGML_SYCL_BACKEND_HPP #endif // GGML_SYCL_BACKEND_HPP

View file

@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
// pool // pool
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device); static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
ggml_sycl_pool & pool(int device) { ggml_sycl_pool & pool(int device) {
if (pools[device] == nullptr) { if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(stream(device,0), device); pools[device] = new_pool_for_device(stream(device,0), device);
@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
ggml_sycl_pool & pool() { ggml_sycl_pool & pool() {
return pool(device); return pool(device);
} }
ggml_sycl_pool & host_pool(int device) {
if (host_pools[device] == nullptr) {
host_pools[device] = new_pool_for_host(stream(device, 0), device);
}
return *host_pools[device];
}
ggml_sycl_pool & host_pool() { return host_pool(device); }
}; };
// common device functions // common device functions

View file

@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
return device_type.str(); return device_type.str();
} }
template <typename Ts> struct matrix_info_t {
oneapi::mkl::transpose transpose_info[2];
Ts value_info[2];
std::int64_t size_info[3];
std::int64_t ld_info[3];
std::int64_t groupsize_info;
};
namespace dpct namespace dpct
{ {
typedef sycl::queue *queue_ptr; typedef sycl::queue *queue_ptr;
@ -1727,26 +1735,13 @@ namespace dpct
}; };
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
oneapi::mkl::transpose b_trans, int m, int n, int k, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
const void *alpha, const void **a, int lda, int ldb, const void * beta, void ** c, int ldc, int batch_size,
const void **b, int ldb, const void *beta, void **c, matrix_info_t<float> * matrix_info) {
int ldc, int batch_size)
{
struct matrix_info_t
{
oneapi::mkl::transpose transpose_info[2];
Ts value_info[2];
std::int64_t size_info[3];
std::int64_t ld_info[3];
std::int64_t groupsize_info;
};
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
matrix_info_t *matrix_info =
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[0] = a_trans;
matrix_info->transpose_info[1] = b_trans; matrix_info->transpose_info[1] = b_trans;
matrix_info->value_info[0] = alpha_value; matrix_info->value_info[0] = alpha_value;
@ -1763,23 +1758,18 @@ namespace dpct
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a), matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1, reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
&(matrix_info->groupsize_info)); reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#else #else
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#endif #endif
q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(e);
cgh.host_task([=] { std::free(matrix_info); }); });
} }
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
@ -2422,25 +2412,11 @@ namespace dpct
/// \param [in] ldc Leading dimension of C. /// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors. /// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
oneapi::mkl::transpose b_trans, int m, int n, int k, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
const void *alpha, const void *a[], const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
library_data_t a_type, int lda, const void *b[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
library_data_t b_type, int ldb, const void *beta, matrix_info_t<float> * matrix_info) {
void *c[], library_data_t c_type, int ldc,
int batch_size, library_data_t scaling_type)
{
if (scaling_type == library_data_t::real_float &&
c_type == library_data_t::complex_float)
{
scaling_type = library_data_t::complex_float;
}
else if (scaling_type == library_data_t::real_double &&
c_type == library_data_t::complex_double)
{
scaling_type = library_data_t::complex_double;
}
std::uint64_t key = std::uint64_t key =
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
switch (key) switch (key)
@ -2449,48 +2425,24 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<float, float, float, float>( detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, beta, c, ldc, batch_size, matrix_info);
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double,
library_data_t::real_double, library_data_t::real_double): library_data_t::real_double, library_data_t::real_double):
{ {
detail::gemm_batch_impl<double, double, double, double>( detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, beta, c, ldc, batch_size, matrix_info);
batch_size);
break;
}
case detail::get_type_combination_id(
library_data_t::complex_float, library_data_t::complex_float,
library_data_t::complex_float, library_data_t::complex_float):
{
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
std::complex<float>, std::complex<float>>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
batch_size);
break;
}
case detail::get_type_combination_id(
library_data_t::complex_double, library_data_t::complex_double,
library_data_t::complex_double, library_data_t::complex_double):
{
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
std::complex<double>, std::complex<double>>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_half): library_data_t::real_half, library_data_t::real_half):
{ {
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
sycl::half>(q, a_trans, b_trans, m, n, k, alpha, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
a, lda, b, ldb, beta, c, ldc,
batch_size);
break; break;
} }
#ifdef __INTEL_MKL__ #ifdef __INTEL_MKL__
@ -2498,19 +2450,16 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float): library_data_t::real_bfloat16, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
b, ldb, beta, c, ldc, batch_size);
break; break;
} }
#endif #endif
@ -2522,10 +2471,9 @@ namespace dpct
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q); dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
float beta_float = float beta_float =
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q); dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
float>(q, a_trans, b_trans, m, n, k, &alpha_float, q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
a, lda, b, ldb, &beta_float, c, ldc, matrix_info);
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2533,8 +2481,7 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>( detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2542,8 +2489,7 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>( detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2557,8 +2503,7 @@ namespace dpct
sycl::half alpha_half(alpha_value); sycl::half alpha_half(alpha_value);
sycl::half beta_half(beta_value); sycl::half beta_half(beta_value);
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
batch_size);
break; break;
} }
default: default:

View file

@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
} }
}; };
struct ggml_sycl_pool_host : public ggml_sycl_pool {
queue_ptr qptr;
int device;
inline static int counter{ 0 };
struct ggml_sycl_buffer {
void * ptr = nullptr;
size_t size = 0;
};
// Set arbitrarly to 64
static constexpr int MAX_POOL_SIZE{ 64 };
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
size_t pool_size = 0;
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
~ggml_sycl_pool_host() {
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
b.ptr = nullptr;
pool_size -= b.size;
b.size = 0;
}
}
counter = 0;
}
void * alloc(size_t size, size_t * actual_size) override {
if (counter == MAX_POOL_SIZE) {
ggml_sycl_buffer b = buffer_pool[0];
void * ptr = b.ptr;
*actual_size = b.size;
counter = 1;
return ptr;
}
ggml_sycl_buffer & b = buffer_pool[counter];
if (b.ptr == nullptr) {
void * ptr;
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
if (!ptr) {
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
return nullptr;
}
pool_size += size;
*actual_size = size;
counter = counter + 1;
return ptr;
} else {
++counter;
b.size = size;
return b.ptr;
}
}
void free(void * ptr, size_t size) override {
// if the pool is not completed add the pointer to it in place of the first nullptr found.
// Otherwise do nothing, pointers will be freed once the pool is deallocated.
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
return;
}
}
}
};
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
// return pool for the host to speed up memory management
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
}
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
// TBD: NO VMM support // TBD: NO VMM support
// if (ggml_sycl_info().devices[device].vmm) { // if (ggml_sycl_info().devices[device].vmm) {
@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13); sycl::range<3> block_dims(1, ne12, ne13);
/* /*
@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
}); });
} }
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::mkl::transpose::trans, *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **)(ptrs_src.get() + 0 * ne23), (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
dpct::library_data_t::real_half, nb01 / nb00, (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
(const void **)(ptrs_src.get() + 1 * ne23),
dpct::library_data_t::real_half, nb11 / nb10, beta,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type)));
} }
} }
catch (sycl::exception const &exc) { catch (sycl::exception const &exc) {
@ -4040,6 +4116,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst); ggml_sycl_op_rwkv_wkv6(ctx, dst);
break; break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
default: default:
return false; return false;
} }
@ -4507,6 +4586,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
return true; return true;
default: default:
return false; return false;

105
ggml/src/ggml-sycl/gla.cpp Normal file
View file

@ -0,0 +1,105 @@
#include <sycl/sycl.hpp>
#include "common.hpp"
template <u_int HEAD_SIZE>
static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale,
const float * k, const float * v, const float * r, const float * td,
const float * s, float * dst) {
const u_int head_size = HEAD_SIZE;
const u_int state_size = C * head_size;
const u_int n_seq_tokens = T / B;
sycl::range<1> block_dims((C / H));
sycl::range<1> grid_dims((B * H));
stream->submit([&](sycl::handler & cgh) {
/* local memory accessors*/
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
u_int tid = item.get_local_id(0);
u_int bid = item.get_group(0);
u_int batch_i = bid / H;
u_int head_i = bid % H;
float state[head_size];
#pragma unroll
for (u_int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
item.barrier(sycl::access::fence_space::local_space); //sync threads
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item.barrier(sycl::access::fence_space::local_space); //sync threads
const float _v = v[t];
float y = 0;
for (u_int j = 0; j < head_size; j += 4) {
const sycl::float4 & k = (sycl::float4 &) (_k[j]);
const sycl::float4 & r = (sycl::float4 &) (_r[j]);
const sycl::float4 & td = (sycl::float4 &) (_td[j]);
sycl::float4 & s = (sycl::float4 &) (state[j]);
sycl::float4 kv;
kv.x() = k.x() * _v;
kv.y() = k.y() * _v;
kv.z() = k.z() * _v;
kv.w() = k.w() * _v;
s.x() = s.x() * td.x() + kv.x();
s.y() = s.y() * td.y() + kv.y();
s.z() = s.z() * td.z() + kv.z();
s.w() = s.w() * td.w() + kv.w();
y += r.x() * s.x();
y += r.y() * s.y();
y += r.z() * s.z();
y += r.w() * s.w();
}
dst[t] = y * scale;
}
#pragma unroll
for (u_int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
});
});
}
void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const float * k_d = static_cast<const float *>(dst->src[0]->data);
const float * v_d = static_cast<const float *>(dst->src[1]->data);
const float * r_d = static_cast<const float *>(dst->src[2]->data);
const float * td_d = static_cast<const float *>(dst->src[3]->data);
const float * s_d = static_cast<const float *>(dst->src[4]->data);
const int64_t B = dst->src[4]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64 || C / H == 128);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float * dst_d = (float *) dst->data;
if (C / H == 64) {
gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
} else {
gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
}
}

View file

@ -0,0 +1,8 @@
#ifndef GGML_SYCL_GLA_HPP
#define GGML_SYCL_GLA_HPP
#include "common.hpp"
void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_GLA_HPP

View file

@ -1,5 +1,20 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED) find_package(Vulkan COMPONENTS glslc REQUIRED)
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
else()
find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
endif()
set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
if (Vulkan_FOUND) if (Vulkan_FOUND)
message(STATUS "Vulkan found") message(STATUS "Vulkan found")
@ -73,19 +88,56 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif() endif()
add_subdirectory(vulkan-shaders) if (NOT CMAKE_CROSSCOMPILING)
add_subdirectory(vulkan-shaders)
if (MSVC)
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
string(TOUPPER ${CONFIG} CONFIG)
set_target_properties(vulkan-shaders-gen PROPERTIES
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()
else()
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
detect_host_compiler()
if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
message(FATAL_ERROR "Host compiler not found")
else()
message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) include(ExternalProject)
# Native build through ExternalProject_Add
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
)
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
endif()
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
if (NOT CMAKE_CROSSCOMPILING) if (CMAKE_CROSSCOMPILING)
set(_ggml_vk_genshaders_cmd "$<TARGET_FILE_DIR:vulkan-shaders-gen>/${_ggml_vk_genshaders_cmd}") set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
endif () endif()
add_custom_command( add_custom_command(
OUTPUT ${_ggml_vk_header} OUTPUT ${_ggml_vk_header}
@ -99,7 +151,7 @@ if (Vulkan_FOUND)
--target-cpp ${_ggml_vk_source} --target-cpp ${_ggml_vk_source}
--no-clean --no-clean
DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} DEPENDS ${_ggml_vk_shader_deps}
COMMENT "Generate vulkan shaders" COMMENT "Generate vulkan shaders"
) )

View file

@ -0,0 +1,15 @@
set(CMAKE_BUILD_TYPE Release)
set(CMAKE_C_FLAGS -O2)
set(CMAKE_CXX_FLAGS -O2)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER)
set(CMAKE_C_COMPILER @HOST_C_COMPILER@)
set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@)
if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC")
foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()

View file

@ -228,6 +228,8 @@ struct vk_device_struct {
vk_pipeline pipeline_repeat_f32; vk_pipeline pipeline_repeat_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_f32;
@ -384,10 +386,13 @@ struct vk_flash_attn_push_constants {
uint32_t nev3; uint32_t nev3;
uint32_t nem1; uint32_t nem1;
uint32_t nb01;
uint32_t nb02; uint32_t nb02;
uint32_t nb03; uint32_t nb03;
uint32_t nb11;
uint32_t nb12; uint32_t nb12;
uint32_t nb13; uint32_t nb13;
uint32_t nb21;
uint32_t nb22; uint32_t nb22;
uint32_t nb23; uint32_t nb23;
uint32_t nb31; uint32_t nb31;
@ -1965,6 +1970,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@ -3689,6 +3708,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cpy_f16_f16; return ctx->device->pipeline_cpy_f16_f16;
} }
} }
if (src->type == GGML_TYPE_F32) {
switch (to) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return ctx->device->pipeline_cpy_f32_quant[to];
default:
break;
}
}
if (to == GGML_TYPE_F32) {
switch (src->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return ctx->device->pipeline_cpy_quant_f32[src->type];
default:
break;
}
}
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -4766,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
} }
assert(pipelines); assert(pipelines);
bool aligned = (KV % pipelines[1]->align) == 0; const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
bool aligned = (KV % pipelines[1]->align) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
vk_pipeline pipeline = pipelines[aligned]; vk_pipeline pipeline = pipelines[aligned];
assert(pipeline); assert(pipeline);
@ -4802,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->device->uma) { if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
Q_uma = d_Q != nullptr; Q_uma = d_Q != nullptr;
K_uma = d_K != nullptr; K_uma = d_K != nullptr;
V_uma = d_V != nullptr; V_uma = d_V != nullptr;
D_uma = d_D != nullptr; D_uma = d_D != nullptr;
if (mask) { if (mask) {
ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr; M_uma = d_M != nullptr;
} }
} }
@ -4848,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
} }
} }
const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ {
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@ -5160,7 +5224,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} }
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
GGML_ASSERT(dst->buffer != nullptr); GGML_ASSERT(dst->buffer != nullptr);
const uint64_t ne00 = src0->ne[0]; const uint64_t ne00 = src0->ne[0];
@ -7905,12 +7969,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true; if (src0_type == GGML_TYPE_F32) {
switch (src1_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
break;
}
} }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { if (src1_type == GGML_TYPE_F32) {
return true; switch (src0_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
break;
}
} }
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true; return true;
} }
@ -8601,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2]; ggml_tensor * src2 = tensor->src[2];
ggml_tensor * src3 = tensor->src[3];
void * tensor_data = tensor->data; void * tensor_data = tensor->data;
@ -8663,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) { if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
} }
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl; std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@ -8707,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) { if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
} }
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl; std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@ -8729,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) { if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
} }
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl; std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);

View file

@ -1,9 +1,11 @@
find_package (Threads REQUIRED) find_package (Threads REQUIRED)
find_package(Vulkan COMPONENTS glslc REQUIRED) find_program(GLSLC_EXECUTABLE glslc)
if(NOT GLSLC_EXECUTABLE)
message(FATAL_ERROR "glslc not found.")
endif()
set(TARGET vulkan-shaders-gen) set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp) add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17) target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)

View file

@ -0,0 +1,51 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = get_doffset() + dst_idx(idx);
uint src_idx = src0_idx_quant(idx, QUANT_K);
const uint a_offset = 0;
const uint ib = src_idx;
const vec2 dm = get_dm(ib, a_offset);
[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
v = v * dm.x + vec4(dm.y);
#if QUANT_R == 2
data_d[dst_idx + j/2 + 0] = v[0];
data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
data_d[dst_idx + j/2 + 1] = v[2];
data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
#else
data_d[dst_idx + j + 0] = v[0];
data_d[dst_idx + j + 1] = v[1];
data_d[dst_idx + j + 2] = v[2];
data_d[dst_idx + j + 3] = v[3];
#endif
}
}

View file

@ -0,0 +1,237 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
const uint xi0 = min(15, int(x0 + 8.5));
const uint xi1 = min(15, int(x1 + 8.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q4_1)
void quantize(uint dst_idx, uint src_idx)
{
float vmin = 1.0/0.0;
float vmax = -vmin;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
const float v = data_s[src_idx + j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(vmin);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
const uint xi0 = min(15, int(x0 + 0.5));
const uint xi1 = min(15, int(x1 + 0.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q5_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
const uint xi0 = min(31, int(x0 + 16.5));
const uint xi1 = min(31, int(x1 + 16.5));
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
}
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
}
#endif
#if defined(DATA_A_Q5_1)
void quantize(uint dst_idx, uint src_idx)
{
float min = data_s[src_idx + 0];
float max = min;
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
const float v = data_s[src_idx + j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = (d != 0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(min);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
const uint xi0 = uint(x0 + 0.5);
const uint xi1 = uint(x1 + 0.5);
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
}
data_q[dst_idx].qh = qh;
}
#endif
#if defined(DATA_A_Q8_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0; // absolute max
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
const float v = data_s[src_idx + j];
amax = max(amax, abs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
const float x0 = data_s[src_idx + j]*id;
data_q[dst_idx].qs[j] = int8_t(round(x0));
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
if (x >= kvalues_iq4nl[15]) return 15;
int ml = 0, mu = 15;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
}
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
}
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = (d != 0.0) ? 1.0/d : 0.0;
float sumqx = 0, sumq2 = 0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
const uint xi0 = best_index(x0);
const uint xi1 = best_index(x1);
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
}
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = dst_idx_quant(idx, QUANT_K);
uint src_idx = get_aoffset() + src0_idx(idx);
quantize(dst_idx, src_idx);
}

View file

@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_
block_q2_K block; block_q2_K block;
}; };
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{ {
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d; const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1]; const uint idx = coordInBlock[1];
const uint iqs = idx;
const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint scalesi = iqs / 16; // 0..15 const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];
uint32_t qs = bl.block.qs[qsi];
const uint scales = bl.block.scales[scalesi]; const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret; return ret;
} }
@ -157,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed16 block; block_q4_K_packed16 block;
}; };
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{ {
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1]; const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1 const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7 const uint is = (idx & 0xE0) >> 5; // 0..7
const f16vec2 loadd = bl.block.d; uvec4 v = bl128.block.q4k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc; uint32_t sc;
uint32_t mbyte; uint32_t mbyte;
uint32_t scidx0 = (is < 4) ? is : (is + 4); uint32_t scale0 = v.y;
uint32_t scidx1 = (is < 4) ? is : (is - 4); uint32_t scale4 = v.z;
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; uint32_t scale8 = v.w;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); uint32_t sc_lo = scale0;
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc); const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte); const float16_t m = loadd.y * float16_t(mbyte);
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F; qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
qs = unpack8(qs)[idx & 1];
float16_t ret = d * float16_t(qs) - m; float16_t ret = d * float16_t(qs) - m;
@ -204,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
block_q5_K_packed16 block; block_q5_K_packed16 block;
}; };
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{ {
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1]; const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1 const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7 const uint is = (idx & 0xE0) >> 5; // 0..7
const uint32_t hm = 0x0101 << is; uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = bl.block.d; const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc; uint32_t sc;
uint32_t mbyte; uint32_t mbyte;
uint32_t scidx0 = (is < 4) ? is : (is + 4); uint32_t scale0 = v.y;
uint32_t scidx1 = (is < 4) ? is : (is - 4); uint32_t scale4 = v.z;
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; uint32_t scale8 = v.w;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); uint32_t sc_lo = scale0;
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc); const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte); const float16_t m = loadd.y * float16_t(mbyte);
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = qh & hm; qh = ((qh >> is) & 0x101) << 4;
qh = unpack8(qh)[idx & 1];
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F; qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1]; qs = unpack8(qs | qh)[idx & 1];
float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; float16_t ret = d * (float16_t(qs)) - m;
return ret; return ret;
} }

View file

@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
uint32_t nev3; uint32_t nev3;
uint32_t nem1; uint32_t nem1;
uint32_t nb01;
uint32_t nb02; uint32_t nb02;
uint32_t nb03; uint32_t nb03;
uint32_t nb11;
uint32_t nb12; uint32_t nb12;
uint32_t nb13; uint32_t nb13;
uint32_t nb21;
uint32_t nb22; uint32_t nb22;
uint32_t nb23; uint32_t nb23;
uint32_t nb31; uint32_t nb31;
@ -146,6 +149,23 @@ void main() {
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
// nb?1 are already divided by the type size and are in units of elements
uint32_t q_stride = p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
k_stride &= ~7;
v_stride &= ~7;
#endif
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q; coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16; coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

View file

@ -54,3 +54,23 @@ uint dst_idx(uint idx) {
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
} }
uint src0_idx_quant(uint idx, uint qk) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;
}
uint dst_idx_quant(uint idx, uint qk) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
}

View file

@ -5,6 +5,80 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
barrier();
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
if (i >= num_blocks_per_row)
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block // 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16; const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16 const uint itid = tid%16; // 0...15
const uint ix = tid/16; const uint ix = tid/16;
const uint step = 8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint l0 = 2*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint nbr_par_th = num_blocks_per_row%it_size;
const uint y_idx = i * QUANT_K + y_offset; const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (uint n = 0; n < num_rows; ++n) { [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
vec2 d = vec2(data_a[ib0 + i].d); calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
uvec2 qs0 = uvec2(unpack8(qs0_u16));
uvec2 qs16 = uvec2(unpack8(qs16_u16));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }

View file

@ -5,6 +5,74 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
// 0, 1, 16, 17
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
if (all_threads) {
barrier();
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block // 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16; const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16 const uint itid = tid%16; // 0...15
const uint ix = tid/16; const uint ix = tid/16;
const uint itid8 = itid%8;
const uint step = 8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_im4 = v_im*4;
const uint v_in = itid - 8*v_im; // 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint32_t m = 0x01010101 << (4 * v_im);
const uint v_in = itid - step*v_im; // 0...15 or 0...7 uint32_t hm_m[4];
[[unroll]] for (uint j = 0; j < 4; ++j)
const uint8_t m = uint8_t(1 << (4 * v_im)); hm_m[j] = m << j;
const uint l0 = 2*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
const uint s_shift = 4 * v_im; const uint s_shift = v_im4 + 2*(itid8/4);
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint nbr_par_th = num_blocks_per_row%it_size;
const uint y_idx = i * QUANT_K + y_offset; const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (uint n = 0; n < num_rows; ++n) { [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
u8vec2 s0 = unpack8(s0_16);
u8vec2 s2 = unpack8(s2_16);
u8vec2 s4 = unpack8(s4_16);
u8vec2 s6 = unpack8(s6_16);
u8vec2 s8 = unpack8(s8_16);
u8vec2 s10 = unpack8(s10_16);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }

View file

@ -6,6 +6,86 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
const FLOAT_TYPE q4_0 = qs0_lo4.x;
const FLOAT_TYPE q4_1 = qs0_lo4.y;
const FLOAT_TYPE q4_2 = qs0_lo4.z;
const FLOAT_TYPE q4_3 = qs0_lo4.w;
const FLOAT_TYPE q4_4 = qs0_hi4.x;
const FLOAT_TYPE q4_5 = qs0_hi4.y;
const FLOAT_TYPE q4_6 = qs0_hi4.z;
const FLOAT_TYPE q4_7 = qs0_hi4.w;
const FLOAT_TYPE q4_8 = qs64_lo4.x;
const FLOAT_TYPE q4_9 = qs64_lo4.y;
const FLOAT_TYPE q4_10 = qs64_lo4.z;
const FLOAT_TYPE q4_11 = qs64_lo4.w;
const FLOAT_TYPE q4_12 = qs64_hi4.x;
const FLOAT_TYPE q4_13 = qs64_hi4.y;
const FLOAT_TYPE q4_14 = qs64_hi4.z;
const FLOAT_TYPE q4_15 = qs64_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block // 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16; const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16 const uint itid = tid%16; // 0...15
const uint ix = tid/16; const uint ix = tid/16;
const uint step = 4; const uint il = itid/4; // 0...3
const uint ir = itid - 4*il; // 0...3
const uint il = itid/step; // 0...3
const uint ir = itid - step*il; // 0...7 or 0...3
const uint n = 4; const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0; const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
const uint y1_idx = i * QUANT_K + y_offset; calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
const uint32_t sc0 = ( scale0.x & 0x3f);
const uint32_t sc1 = ( scale0.y & 0x3f);
const uint32_t sc2 = ( scale4.x & 0x3f);
const uint32_t sc3 = ( scale4.y & 0x3f);
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
const uint32_t q4_0 = qs0_lo4.x;
const uint32_t q4_1 = qs0_lo4.y;
const uint32_t q4_2 = qs0_lo4.z;
const uint32_t q4_3 = qs0_lo4.w;
const uint32_t q4_4 = qs0_hi4.x;
const uint32_t q4_5 = qs0_hi4.y;
const uint32_t q4_6 = qs0_hi4.z;
const uint32_t q4_7 = qs0_hi4.w;
const uint32_t q4_8 = qs64_lo4.x;
const uint32_t q4_9 = qs64_lo4.y;
const uint32_t q4_10 = qs64_lo4.z;
const uint32_t q4_11 = qs64_lo4.w;
const uint32_t q4_12 = qs64_hi4.x;
const uint32_t q4_13 = qs64_hi4.y;
const uint32_t q4_14 = qs64_hi4.z;
const uint32_t q4_15 = qs64_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }

View file

@ -6,6 +6,118 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block // 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16; const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16 const uint itid = tid%16; // 0...15
const uint ix = tid/16; const uint ix = tid/16;
const uint il = itid/4; // 0...3 const uint il = itid/4; // 0...3
const uint ir = itid - 4*il; // 0...7 or 0...3 const uint ir = itid - 4*il; // 0...3
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2; const uint v_in = il % 2;
@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0; const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
const uint y1_idx = i * QUANT_K + y_offset; calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
const uint32_t sc0 = ( scale0.x & 0x3f);
const uint32_t sc1 = ( scale0.y & 0x3f);
const uint32_t sc2 = ( scale4.x & 0x3f);
const uint32_t sc3 = ( scale4.y & 0x3f);
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
const uint32_t q4_0 = qs0_16_lo4.x;
const uint32_t q4_1 = qs0_16_lo4.y;
const uint32_t q4_2 = qs0_16_lo4.z;
const uint32_t q4_3 = qs0_16_lo4.w;
const uint32_t q4_4 = qs0_16_hi4.x;
const uint32_t q4_5 = qs0_16_hi4.y;
const uint32_t q4_6 = qs0_16_hi4.z;
const uint32_t q4_7 = qs0_16_hi4.w;
const uint32_t q4_8 = qs64_80_lo4.x;
const uint32_t q4_9 = qs64_80_lo4.y;
const uint32_t q4_10 = qs64_80_lo4.z;
const uint32_t q4_11 = qs64_80_lo4.w;
const uint32_t q4_12 = qs64_80_hi4.x;
const uint32_t q4_13 = qs64_80_hi4.y;
const uint32_t q4_14 = qs64_80_hi4.z;
const uint32_t q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }

View file

@ -6,7 +6,77 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
if (all_threads) {
barrier();
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
[[unroll]] for (uint l = 0; l < 4; ++l) {
sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
}
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
}
}
}
void compute_outputs(const uint first_row, const uint num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block // 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16; const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16 const uint itid = tid%16; // 0...15
const uint ix = tid/16; const uint ix = tid/16;
const uint step = 8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
const uint is = v_in / 4; const uint is = v_in / 4;
@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint s_offset = 8*v_im + is; const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint nbr_par_th = num_blocks_per_row%it_size;
const uint y_idx = i * QUANT_K + y_offset; const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (uint n = 0; n < num_rows; ++n) { [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
FLOAT_TYPE scales[4];
scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
uvec4 q0 = uvec4(unpack8(q0_u32));
uvec4 q1 = uvec4(unpack8(q1_u32));
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
}
temp[j][n] += sum * d;
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }

View file

@ -227,6 +227,11 @@ struct block_q4_K_packed32
uint32_t qs[QUANT_K_Q4_K/2/4]; uint32_t qs[QUANT_K_Q4_K/2/4];
}; };
struct block_q4_K_packed128
{
uvec4 q4k[9];
};
#if defined(DATA_A_Q4_K) #if defined(DATA_A_Q4_K)
#define QUANT_K QUANT_K_Q4_K #define QUANT_K QUANT_K_Q4_K
#define A_TYPE block_q4_K #define A_TYPE block_q4_K
@ -252,6 +257,11 @@ struct block_q5_K_packed16
uint16_t qs[QUANT_K_Q5_K/2/2]; uint16_t qs[QUANT_K_Q5_K/2/2];
}; };
struct block_q5_K_packed128
{
uvec4 q5k[11];
};
#if defined(DATA_A_Q5_K) #if defined(DATA_A_Q5_K)
#define QUANT_K QUANT_K_Q5_K #define QUANT_K QUANT_K_Q5_K
#define A_TYPE block_q5_K #define A_TYPE block_q5_K

View file

@ -30,8 +30,6 @@
#include <fcntl.h> #include <fcntl.h>
#endif #endif
#include <vulkan/vulkan_core.h>
#define ASYNCIO_CONCURRENCY 64 #define ASYNCIO_CONCURRENCY 64
std::mutex lock; std::mutex lock;
@ -419,6 +417,11 @@ void process_shaders() {
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});

View file

@ -3450,12 +3450,14 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
} }
// ggml_soft_max_back // ggml_soft_max_ext_back
static struct ggml_tensor * ggml_soft_max_back_impl( static struct ggml_tensor * ggml_soft_max_ext_back_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
float scale,
float max_bias,
bool inplace) { bool inplace) {
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@ -3463,21 +3465,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
result->src[0] = a; result->src[0] = a;
result->src[1] = b; result->src[1] = b;
memcpy((float *) result->op_params + 0, &scale, sizeof(float));
memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
return result; return result;
} }
struct ggml_tensor * ggml_soft_max_back( struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b) { struct ggml_tensor * b,
return ggml_soft_max_back_impl(ctx, a, b, false); float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
} }
struct ggml_tensor * ggml_soft_max_back_inplace( struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b) { struct ggml_tensor * b,
return ggml_soft_max_back_impl(ctx, a, b, true); float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
} }
// ggml_rope // ggml_rope
@ -3695,7 +3704,7 @@ void ggml_rope_yarn_corr_dims(
// ggml_rope_back // ggml_rope_back
struct ggml_tensor * ggml_rope_back( struct ggml_tensor * ggml_rope_ext_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
@ -3709,29 +3718,32 @@ struct ggml_tensor * ggml_rope_back(
float attn_factor, float attn_factor,
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
GGML_ASSERT(ggml_is_vector(b)); struct ggml_tensor * result = ggml_rope_ext(
GGML_ASSERT(b->type == GGML_TYPE_I32); ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
GGML_ASSERT(a->ne[2] == b->ne[0]); result->op = GGML_OP_ROPE_BACK;
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_BACK;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result; return result;
} }
struct ggml_tensor * ggml_rope_multi_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
struct ggml_tensor * result = ggml_rope_multi(
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
result->op = GGML_OP_ROPE_BACK;
return result;
}
// ggml_clamp // ggml_clamp
struct ggml_tensor * ggml_clamp( struct ggml_tensor * ggml_clamp(
@ -5073,10 +5085,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
struct ggml_tensor * c) { struct ggml_tensor * c) {
GGML_ASSERT(ggml_are_same_shape(a, b)); GGML_ASSERT(ggml_is_scalar(a));
GGML_ASSERT(ggml_is_scalar(c)); GGML_ASSERT(ggml_are_same_shape(b, c));
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->src[0] = a; result->src[0] = a;
@ -5255,7 +5267,7 @@ static void ggml_sub_or_set(
} }
static void ggml_compute_backward( static void ggml_compute_backward(
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
struct ggml_tensor * tensor = cgraph->nodes[i]; struct ggml_tensor * tensor = cgraph->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
@ -5399,7 +5411,7 @@ static void ggml_compute_backward(
if (src0_needs_grads) { if (src0_needs_grads) {
float eps; float eps;
memcpy(&eps, tensor->op_params, sizeof(float)); memcpy(&eps, tensor->op_params, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
} }
} break; } break;
case GGML_OP_MUL_MAT: { case GGML_OP_MUL_MAT: {
@ -5582,7 +5594,13 @@ static void ggml_compute_backward(
} break; } break;
case GGML_OP_SOFT_MAX: { case GGML_OP_SOFT_MAX: {
if (src0_needs_grads) { if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
} }
GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
} break; } break;
@ -5594,6 +5612,7 @@ static void ggml_compute_backward(
//const int n_ctx = ((int32_t *) tensor->op_params)[3]; //const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4] = {0, 0, 0, 0};
memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@ -5601,10 +5620,14 @@ static void ggml_compute_backward(
memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
memcpy(&sections, tensor->op_params + 11, sizeof(sections));
ggml_add_or_set(ctx, cgraph, isrc0, struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
} }
GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
} break; } break;
@ -5618,7 +5641,7 @@ static void ggml_compute_backward(
const int32_t d1 = ggml_get_op_params_i32(tensor, 5); const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
} }
} break; } break;
case GGML_OP_POOL_2D: { case GGML_OP_POOL_2D: {
@ -5661,7 +5684,7 @@ static void ggml_compute_backward(
} break; } break;
case GGML_UNARY_OP_SILU: { case GGML_UNARY_OP_SILU: {
if (src0_needs_grads) { if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
} }
} break; } break;
case GGML_UNARY_OP_EXP: { case GGML_UNARY_OP_EXP: {
@ -5678,7 +5701,7 @@ static void ggml_compute_backward(
} break; } break;
case GGML_OP_CROSS_ENTROPY_LOSS: { case GGML_OP_CROSS_ENTROPY_LOSS: {
if (src0_needs_grads) { if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
} }
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break; } break;

View file

@ -648,6 +648,10 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ok = ok && data != nullptr; ok = ok && data != nullptr;
if (ok) {
ggml_set_name(data, "GGUF tensor data binary blob");
}
// read the binary blob with the tensor data // read the binary blob with the tensor data
ok = ok && gr.read(data->data, ctx->size); ok = ok && gr.read(data->data, ctx->size);

View file

@ -288,9 +288,6 @@ extern "C" {
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split; const float * tensor_split;
// comma separated list of RPC servers to use for offloading
const char * rpc_servers;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues. // If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted. // If it returns false, model loading is immediately aborted.
@ -418,10 +415,20 @@ extern "C" {
struct llama_model_params params), struct llama_model_params params),
"use llama_model_load_from_file instead"); "use llama_model_load_from_file instead");
// Load the model from a file
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
// If the split file name does not follow this pattern, use llama_model_load_from_splits
LLAMA_API struct llama_model * llama_model_load_from_file( LLAMA_API struct llama_model * llama_model_load_from_file(
const char * path_model, const char * path_model,
struct llama_model_params params); struct llama_model_params params);
// Load the model from multiple splits (support custom naming scheme)
// The paths must be in the correct order
LLAMA_API struct llama_model * llama_model_load_from_splits(
const char ** paths,
size_t n_paths,
struct llama_model_params params);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead"); "use llama_model_free instead");
@ -951,7 +958,7 @@ extern "C" {
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocabable_get_text instead"); DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");

View file

@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
} }
} }
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
std::vector<std::string> paths;
std::string split_prefix;
std::vector<char> buf(llama_path_max(), 0);
{
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
if (!ret) {
throw std::runtime_error(format("invalid split file name: %s", path.c_str()));
}
split_prefix = std::string(buf.data(), ret);
}
if (split_prefix.empty()) {
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
}
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
paths.push_back(std::string(buf.data(), ret));
}
return paths;
}
namespace GGUFMeta { namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)> template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
struct GKV_Base_Type { struct GKV_Base_Type {
@ -413,7 +440,12 @@ namespace GGUFMeta {
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { llama_model_loader::llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits,
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p) {
int trace = 0; int trace = 0;
if (getenv("LLAMA_TRACE")) { if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE")); trace = atoi(getenv("LLAMA_TRACE"));
@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
} }
} }
// Load the main GGUF
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
struct gguf_init_params params = { struct gguf_init_params params = {
/*.no_alloc = */ true, /*.no_alloc = */ true,
@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
// Load additional GGML contexts // Load additional GGML contexts
if (n_split > 1) { if (n_split > 1) {
// make sure the main file is loaded first
uint16_t idx = 0; uint16_t idx = 0;
get_key(llm_kv(LLM_KV_SPLIT_NO), idx); const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
get_key(kv_split_no, idx);
if (idx != 0) { if (idx != 0) {
throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
} }
std::vector<char> split_prefix(llama_path_max(), 0); // generate list of splits if needed
if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) { if (splits.empty()) {
throw std::runtime_error(format("invalid split file: %s", fname.c_str())); splits = llama_get_list_splits(fname, idx, n_split);
}
// in case user give a custom list of splits, check if it matches the expected number
if (n_split != (uint16_t)splits.size()) {
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
} }
if (trace > 0) { if (trace > 0) {
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
} }
std::vector<char> split_path(llama_path_max(), 0); // load other splits
for (idx = 1; idx < n_split; idx++) { for (idx = 1; idx < n_split; idx++) {
llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split); const char * fname_split = splits[idx].c_str();
struct gguf_init_params split_params = { struct gguf_init_params split_params = {
/*.no_alloc = */ true, /*.no_alloc = */ true,
/*.ctx = */ &ctx, /*.ctx = */ &ctx,
}; };
gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) }; gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
if (!ctx_gguf) { if (!ctx_gguf) {
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data())); throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
} }
files.emplace_back(new llama_file(split_path.data(), "rb")); // check idx
{
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
if (kid < 0) {
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
}
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
if (idx_gguf != idx) {
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
}
}
files.emplace_back(new llama_file(fname_split, "rb"));
contexts.emplace_back(ctx); contexts.emplace_back(ctx);
// Save tensors data offset info of the shard. // Save tensors data offset info of the shard.

View file

@ -90,7 +90,12 @@ struct llama_model_loader {
size_t size_data = 0; size_t size_data = 0;
std::vector<std::pair<size_t, size_t>> mmaps_used; std::vector<std::pair<size_t, size_t>> mmaps_used;
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p); llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p);
template<typename T> template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type typename std::enable_if<std::is_integral<T>::value, bool>::type

View file

@ -2203,6 +2203,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
} }
} break; } break;
case LLM_ARCH_PHIMOE:
{
const int64_t n_embd_head = n_embd / n_head;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO:
{ {
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -3717,7 +3761,6 @@ struct llama_model_params llama_model_default_params() {
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0, /*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr, /*.tensor_split =*/ nullptr,
/*.rpc_servers =*/ nullptr,
/*.progress_callback =*/ nullptr, /*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr, /*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr, /*.kv_overrides =*/ nullptr,

View file

@ -323,8 +323,6 @@ struct llama_model {
// gguf metadata // gguf metadata
std::unordered_map<std::string, std::string> gguf_kv; std::unordered_map<std::string, std::string> gguf_kv;
std::vector<std::string> rpc_servers;
// list of devices used in this model // list of devices used in this model
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;

View file

@ -526,7 +526,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
kv_overrides = v->data(); kv_overrides = v->data();
} }
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); std::vector<std::string> splits = {};
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
ml.init_mappings(false); // no prefetching ml.init_mappings(false); // no prefetching
llama_model model(llama_model_default_params()); llama_model model(llama_model_default_params());

View file

@ -439,7 +439,7 @@ struct llm_tokenizer_bpe_session {
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__); "Are you sure this is what you want?\n", __FUNCTION__);
} }
if (vocab.get_add_bos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
LLAMA_LOG_WARN( LLAMA_LOG_WARN(
"%s: Added a EOS token to the prompt as specified by the model but the prompt " "%s: Added a EOS token to the prompt as specified by the model but the prompt "
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "

View file

@ -31,7 +31,7 @@
#endif #endif
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
// loading time will be recalculated after the first eval, so // loading time will be recalculated after the first eval, so
// we take page faults deferred by mmap() into consideration // we take page faults deferred by mmap() into consideration
model.t_load_us = 0; model.t_load_us = 0;
@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
model.t_start_us = tm.t_start_us; model.t_start_us = tm.t_start_us;
try { try {
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
ml.print_info(); ml.print_info();
@ -4642,7 +4642,7 @@ struct llm_build_context {
0); 0);
cb(v_states, "v_states", il); cb(v_states, "v_states", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
q_pe = ggml_rope_ext( q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, rope_factors, ctx0, q_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -4651,7 +4651,7 @@ struct llm_build_context {
cb(q_pe, "q_pe", il); cb(q_pe, "q_pe", il);
// shared RoPE key // shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
k_pe = ggml_rope_ext( k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, rope_factors, ctx0, k_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -6496,7 +6496,7 @@ struct llm_build_context {
0); 0);
cb(v_states, "v_states", il); cb(v_states, "v_states", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
q_pe = ggml_rope_ext( q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr, ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -6505,7 +6505,7 @@ struct llm_build_context {
cb(q_pe, "q_pe", il); cb(q_pe, "q_pe", il);
// shared RoPE key // shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
k_pe = ggml_rope_ext( k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr, ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -9374,14 +9374,9 @@ int64_t llama_time_us(void) {
return ggml_time_us(); return ggml_time_us();
} }
struct llama_model * llama_load_model_from_file( static struct llama_model * llama_model_load_from_file_impl(
const char * path_model, const std::string & path_model,
struct llama_model_params params) { std::vector<std::string> & splits,
return llama_model_load_from_file(path_model, params);
}
struct llama_model * llama_model_load_from_file(
const char * path_model,
struct llama_model_params params) { struct llama_model_params params) {
ggml_time_init(); ggml_time_init();
@ -9404,47 +9399,6 @@ struct llama_model * llama_model_load_from_file(
}; };
} }
if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
// split the servers set them into model->rpc_servers
std::string servers(params.rpc_servers);
size_t pos = 0;
while ((pos = servers.find(',')) != std::string::npos) {
std::string server = servers.substr(0, pos);
model->rpc_servers.push_back(server);
servers.erase(0, pos + 1);
}
model->rpc_servers.push_back(servers);
}
// add RPC devices
if (!model->rpc_servers.empty()) {
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
llama_model_free(model);
return nullptr;
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
llama_model_free(model);
return nullptr;
}
for (const std::string & server : model->rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
model->devices.push_back(dev);
} else {
LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
llama_model_free(model);
return nullptr;
}
}
}
// create list of devices to use with this model // create list of devices to use with this model
if (params.devices) { if (params.devices) {
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
@ -9485,7 +9439,7 @@ struct llama_model * llama_model_load_from_file(
LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024); LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
} }
const int status = llama_model_load(path_model, *model, params); const int status = llama_model_load(path_model, splits, *model, params);
GGML_ASSERT(status <= 0); GGML_ASSERT(status <= 0);
if (status < 0) { if (status < 0) {
if (status == -1) { if (status == -1) {
@ -9501,6 +9455,35 @@ struct llama_model * llama_model_load_from_file(
return model; return model;
} }
// deprecated
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
return llama_model_load_from_file(path_model, params);
}
struct llama_model * llama_model_load_from_file(
const char * path_model,
struct llama_model_params params) {
std::vector<std::string> splits = {};
return llama_model_load_from_file_impl(path_model, splits, params);
}
struct llama_model * llama_model_load_from_splits(
const char ** paths,
size_t n_paths,
struct llama_model_params params) {
std::vector<std::string> splits;
if (n_paths == 0) {
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
return nullptr;
}
for (size_t i = 0; i < n_paths; ++i) {
splits.push_back(paths[i]);
}
return llama_model_load_from_file_impl(splits.front(), splits, params);
}
struct llama_context * llama_init_from_model( struct llama_context * llama_init_from_model(
struct llama_model * model, struct llama_model * model,
struct llama_context_params params) { struct llama_context_params params) {

View file

@ -7,18 +7,17 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <codecvt>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <locale>
#include <map> #include <map>
#include <regex> #include <regex>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <locale>
#include <codecvt>
size_t unicode_len_utf8(char src) { size_t unicode_len_utf8(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

View file

@ -1,3 +1,5 @@
llama_add_compile_flags()
function(llama_test target) function(llama_test target)
include(CMakeParseArguments) include(CMakeParseArguments)
set(options) set(options)

View file

@ -780,7 +780,7 @@ struct test_case {
} }
} }
if (!any_params) { if (!any_params) {
printf("not supported [%s] \n", op_name); printf("not supported [%s] \n", op_desc(out).c_str());
supported = false; supported = false;
} }
if (!supported) { if (!supported) {
@ -1130,6 +1130,59 @@ struct test_get_rows : public test_case {
} }
}; };
// GGML_OP_GET_ROWS_BACK
struct test_get_rows_back : public test_case {
const ggml_type type;
const int n; // cols
const int m; // rows
const int r; // rows to get
const int b; // batch size
const bool v; // view (non-contiguous src1)
std::string vars() override {
return VARS_TO_STR6(type, n, m, r, b, v);
}
test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
: type(type), n(n), m(m), r(r), b(b), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
ggml_set_name(in_forward, "in_forward");
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
ggml_set_name(rows, "rows");
if (v) {
rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
ggml_set_name(rows, "view_of_rows");
}
ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// rows
std::vector<int> data(r*b);
for (int i = 0; i < r*b; i++) {
data[i] = rand() % m;
}
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_ARGMAX // GGML_OP_ARGMAX
struct test_argmax : public test_case { struct test_argmax : public test_case {
const ggml_type type; const ggml_type type;
@ -1531,6 +1584,39 @@ struct test_scale : public test_case {
} }
}; };
// GGML_OP_SILU_BACK
struct test_silu_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_silu_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_silu_back(ctx, a, grad);
ggml_set_name(out, "out");
return out;
}
bool grad_precise() override {
return true;
}
};
// GGML_OP_NORM // GGML_OP_NORM
struct test_norm : public test_case { struct test_norm : public test_case {
const ggml_type type; const ggml_type type;
@ -1583,11 +1669,56 @@ struct test_rms_norm : public test_case {
return out; return out;
} }
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
float grad_eps() override {
return 1.0f;
}
bool grad_precise() override { bool grad_precise() override {
return true; return true;
} }
}; };
// GGML_OP_RMS_NORM_BACK
struct test_rms_norm_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
};
// GGML_OP_SSM_CONV // GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case { struct test_ssm_conv : public test_case {
const ggml_type type; const ggml_type type;
@ -1855,10 +1986,11 @@ struct test_out_prod : public test_case {
const int64_t n; const int64_t n;
const int64_t k; const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4 const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const bool trans_b; const bool trans_b;
std::string vars() override { std::string vars() override {
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b); return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
} }
double max_nmse_err() override { double max_nmse_err() override {
@ -1868,8 +2000,9 @@ struct test_out_prod : public test_case {
test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32, int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10}, std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
bool trans_b = false) bool trans_b = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {} : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]); ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
@ -1877,10 +2010,10 @@ struct test_out_prod : public test_case {
ggml_tensor * b; ggml_tensor * b;
if (trans_b) { if (trans_b) {
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
b = ggml_transpose(ctx, b); b = ggml_transpose(ctx, b);
} else { } else {
b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
} }
ggml_set_name(b, "b"); ggml_set_name(b, "b");
@ -2191,8 +2324,38 @@ struct test_soft_max : public test_case {
} }
}; };
// GGML_OP_SOFT_MAX_BACK
struct test_soft_max_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float scale;
const float max_bias;
// GGML_OP_ROPE std::string vars() override {
return VARS_TO_STR4(type, ne, scale, max_bias);
}
test_soft_max_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
float scale = 1.0f,
float max_bias = 0.0f)
: type(type), ne(ne), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_ROPE + GGML_OP_ROPE_BACK
struct test_rope : public test_case { struct test_rope : public test_case {
const ggml_type type; const ggml_type type;
const std::array<int64_t, 4> ne_a; const std::array<int64_t, 4> ne_a;
@ -2204,29 +2367,36 @@ struct test_rope : public test_case {
float af; // attn_factor float af; // attn_factor
bool ff; bool ff;
int v; // view (1 : non-contiguous a) int v; // view (1 : non-contiguous a)
bool forward;
std::string vars() override { std::string vars() override {
// forward can be inferred from the op, does not need to be printed
return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v); return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
} }
test_rope(ggml_type type = GGML_TYPE_F32, test_rope(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {10, 5, 3, 1}, std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0) int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {} float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a; ggml_tensor * a;
if (v & 1) { if (v & 1) {
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3; auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data()); a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a); if (forward) {
ggml_set_param(ctx, a);
}
ggml_set_name(a, "a"); ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a"); ggml_set_name(a, "view_of_a");
} else { } else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data()); a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(ctx, a); if (forward) {
ggml_set_param(ctx, a);
}
ggml_set_name(a, "a"); ggml_set_name(a, "a");
} }
@ -2252,14 +2422,26 @@ struct test_rope : public test_case {
if (is_vision) { if (is_vision) {
GGML_ASSERT(n_dims/4 > 0); GGML_ASSERT(n_dims/4 > 0);
int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); if (forward) {
out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
} else { } else {
GGML_ASSERT(n_dims/3 > 0); GGML_ASSERT(n_dims/3 > 0);
int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0}; int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); if (forward) {
out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
} }
} else { } else {
out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); if (forward) {
out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
} }
ggml_set_name(out, "out"); ggml_set_name(out, "out");
@ -2864,9 +3046,10 @@ struct test_flash_attn_ext : public test_case {
const float logit_softcap; // Gemma 2 const float logit_softcap; // Gemma 2
const ggml_type type_KV; const ggml_type type_KV;
std::array<int32_t, 4> permute;
std::string vars() override { std::string vars() override {
return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV); return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
} }
double max_nmse_err() override { double max_nmse_err() override {
@ -2881,19 +3064,33 @@ struct test_flash_attn_ext : public test_case {
} }
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {} std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1); auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
int64_t ne_perm[4];
for (int i = 0; i < 4; ++i) {
ne_perm[permute[i]] = ne[i];
}
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
}
return t;
};
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
ggml_set_name(q, "q"); ggml_set_name(q, "q");
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(k, "k"); ggml_set_name(k, "k");
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(v, "v"); ggml_set_name(v, "v");
ggml_tensor * m = nullptr; ggml_tensor * m = nullptr;
@ -2961,6 +3158,40 @@ struct test_cross_entropy_loss : public test_case {
} }
}; };
// GGML_OP_CROSS_ENTROPY_LOSS_BACK
struct test_cross_entropy_loss_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
ggml_set_name(grad, "grad");
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(logits, "logits");
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(labels, "labels");
// Ensure labels add up to 1:
labels = ggml_soft_max(ctx, labels);
ggml_set_name(labels, "labels_normalized");
ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_OPT_STEP_ADAMW // GGML_OP_OPT_STEP_ADAMW
struct test_opt_step_adamw : public test_case { struct test_opt_step_adamw : public test_case {
const ggml_type type; const ggml_type type;
@ -3460,6 +3691,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
} }
} }
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
}
}
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
}
for (ggml_type type_input : {GGML_TYPE_F32}) { for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) { for (int k0 : {1, 3}) {
@ -3582,6 +3823,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
} }
} }
for (ggml_type type_dst : {GGML_TYPE_F32}) {
for (ggml_type type_src : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
}
}
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
@ -3638,10 +3885,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_silu_back());
for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) { for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
} }
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
@ -3781,22 +4030,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type_a : base_types) { for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1})); for (int n : {1, 16}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); for (int k : {1, 16}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); for (int bs2 : {1, 3}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); for (int bs3 : {1, 3}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); for (int nr2 : {1, 2}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); for (int nr3 : {1, 2}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
}
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1})); }
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true)); }
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); }
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); }
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); }
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
} }
} }
@ -3839,12 +4085,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
} }
} }
} }
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
{ for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias));
test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
}
}
}
}
for (bool fw : {true, false}) { // fw == forward
bool all = true; bool all = true;
for (float v : { 0, 1 }) { for (float v : { 0, 1 }) {
@ -3853,29 +4110,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float af : { 1.0f, 1.4245f }) { for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (bool ff : {false, true}) { // freq_factors for (bool ff : {false, true}) { // freq_factors
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
if (all) { if (all) {
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
} }
if (all) { if (all) {
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
} }
if (all) { if (all) {
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT) test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
} }
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
} }
} }
@ -3925,6 +4182,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 35, }) { for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
}
} }
} }
} }
@ -3934,7 +4195,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
} }
} }
test_cases.emplace_back(new test_cross_entropy_loss()); test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
// these tests are disabled to save execution time, but they can be handy for debugging // these tests are disabled to save execution time, but they can be handy for debugging

View file

@ -48,7 +48,7 @@ enum handcrafted_file_type {
HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data, HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data,
}; };
std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) { static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
switch (hft) { switch (hft) {
case HANDCRAFTED_HEADER_BAD_MAGIC: return "HEADER_BAD_MAGIC"; case HANDCRAFTED_HEADER_BAD_MAGIC: return "HEADER_BAD_MAGIC";
case HANDCRAFTED_HEADER_BAD_VERSION_1: return "HEADER_BAD_VERSION_1"; case HANDCRAFTED_HEADER_BAD_VERSION_1: return "HEADER_BAD_VERSION_1";
@ -99,7 +99,7 @@ static bool expect_context_not_null(const enum handcrafted_file_type hft) {
typedef std::pair<enum ggml_type, std::array<int64_t, GGML_MAX_DIMS>> tensor_config_t; typedef std::pair<enum ggml_type, std::array<int64_t, GGML_MAX_DIMS>> tensor_config_t;
std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) { static std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) {
std::vector<tensor_config_t> tensor_configs; std::vector<tensor_config_t> tensor_configs;
tensor_configs.reserve(100); tensor_configs.reserve(100);
@ -122,7 +122,7 @@ std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) {
return tensor_configs; return tensor_configs;
} }
std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937 rng) { static std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937 rng) {
std::vector<std::pair<enum gguf_type, enum gguf_type>> kv_types; std::vector<std::pair<enum gguf_type, enum gguf_type>> kv_types;
kv_types.reserve(100); kv_types.reserve(100);
@ -626,8 +626,6 @@ static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const u
bool ok = true; bool ok = true;
const uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
for (int i = 0; i < int(tensor_configs.size()); ++i) { for (int i = 0; i < int(tensor_configs.size()); ++i) {
const ggml_type type = tensor_configs[i].first; const ggml_type type = tensor_configs[i].first;
const std::array<int64_t, GGML_MAX_DIMS> shape = tensor_configs[i].second; const std::array<int64_t, GGML_MAX_DIMS> shape = tensor_configs[i].second;
@ -866,13 +864,13 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
case GGUF_TYPE_COUNT: case GGUF_TYPE_COUNT:
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} break; }
} }
} break; } break;
case GGUF_TYPE_COUNT: case GGUF_TYPE_COUNT:
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} break; }
} }
} }
@ -938,7 +936,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
} }
if (type == GGUF_TYPE_ARRAY) { if (type == GGUF_TYPE_ARRAY) {
const int arr_n = gguf_get_arr_n(ctx, id); const size_t arr_n = gguf_get_arr_n(ctx, id);
if (arr_n != gguf_get_arr_n(other, idx_other)) { if (arr_n != gguf_get_arr_n(other, idx_other)) {
ok = false; ok = false;
continue; continue;
@ -953,7 +951,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
if (type_arr == GGUF_TYPE_BOOL) { if (type_arr == GGUF_TYPE_BOOL) {
const int8_t * data = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx, id)); const int8_t * data = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx, id));
const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other)); const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
for (int arr_i = 0; arr_i < arr_n; ++arr_i) { for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
if (bool(data[arr_i]) != bool(data_other[arr_i])) { if (bool(data[arr_i]) != bool(data_other[arr_i])) {
ok = false; ok = false;
} }
@ -962,7 +960,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
} }
if (type_arr == GGUF_TYPE_STRING) { if (type_arr == GGUF_TYPE_STRING) {
for (int arr_i = 0; arr_i < arr_n; ++arr_i) { for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
const std::string str = gguf_get_arr_str(ctx, id, arr_i); const std::string str = gguf_get_arr_str(ctx, id, arr_i);
const std::string str_other = gguf_get_arr_str(other, idx_other, arr_i); const std::string str_other = gguf_get_arr_str(other, idx_other, arr_i);
if (str != str_other) { if (str != str_other) {
@ -1033,6 +1031,12 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
struct ggml_tensor * t_orig = ggml_get_first_tensor(orig); struct ggml_tensor * t_orig = ggml_get_first_tensor(orig);
struct ggml_tensor * t_read = ggml_get_first_tensor(read); struct ggml_tensor * t_read = ggml_get_first_tensor(read);
if (std::string(t_read->name) != "GGUF tensor data binary blob") {
return false;
}
t_read = ggml_get_next_tensor(read, t_read);
while (t_orig) { while (t_orig) {
if (!t_read) { if (!t_read) {
ok = false; ok = false;
@ -1051,13 +1055,13 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
} }
t_orig = ggml_get_next_tensor(orig, t_orig); t_orig = ggml_get_next_tensor(orig, t_orig);
t_read = ggml_get_next_tensor(orig, t_read); t_read = ggml_get_next_tensor(read, t_read);
} }
if (t_read) { if (t_read) {
ok = false; ok = false;
} }
return true; return ok;
} }
static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) { static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {

View file

@ -80,18 +80,18 @@ run_conversion_and_inference_lora() {
# Run inference # Run inference
echo -e "\n\n---------------------------\n\n" echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..." echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ OUTPUT_BASE=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0) -p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n" echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..." echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \ OUTPUT_LORA_HOT=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \ --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n" echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..." echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \ OUTPUT_LORA_MERGED=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0) -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
# Remove any initial white space # Remove any initial white space

View file

@ -144,7 +144,6 @@ static void test_penalties(
sampler_tester tester(probs, probs_expected); sampler_tester tester(probs, probs_expected);
const size_t n_vocab = probs.size();
auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
for (size_t i = 0; i < last_tokens.size(); i++) { for (size_t i = 0; i < last_tokens.size(); i++) {