Merge remote-tracking branch 'origin/master' into sl/remove-task-type

commit 9b00a7eba5
Author: slaren
Date: 2024-06-23 19:22:01 +02:00
47 changed files with 23142 additions and 21240 deletions

View file

@@ -28,4 +28,5 @@ indent_size = 2
 indent_style = tab

 [examples/cvector-generator/*.txt]
+trim_trailing_whitespace = unset
 insert_final_newline = unset

View file

@@ -30,7 +30,7 @@ jobs:
     strategy:
       matrix:
-        sanitizer: [ADDRESS, THREAD, UNDEFINED]
+        sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
        build_type: [RelWithDebInfo]
        include:
          - build_type: Release

.gitignore
View file

@@ -1,90 +1,123 @@
-*.o
+# Extensions
 *.a
-*.so
+*.bat
+*.bin
+*.dll
+*.dot
+*.etag
+*.exe
+*.gcda
+*.gcno
+*.gcov
 *.gguf
 *.gguf.json
-*.bin
-*.exe
-*.dll
-*.log
-*.gcov
-*.gcno
-*.gcda
-*.dot
-*.bat
-*.tmp
-*.metallib
-*.etag
 *.lastModified
-.DS_Store
+*.log
-.build/
+*.metallib
+*.o
+*.so
+*.tmp

+# IDE / OS
 .cache/
 .ccls-cache/
 .direnv/
+.DS_Store
 .envrc
+.idea/
 .swiftpm
-.venv
-.clang-tidy
 .vs/
 .vscode/
-.idea/
+nppBackup
-ggml-metal-embed.metal

-lcov-report/
+# Coverage
 gcovr-report/
+lcov-report/

+# Build Artifacts
 tags
+.build/
 build*
+!build-info.cmake
+!build-info.cpp.in
+!build-info.sh
 !build.zig
-cmake-build-*
+/libllama.so
+/llama-*
 android-ndk-*
+arm_neon.h
+cmake-build-*
+CMakeSettings.json
+compile_commands.json
+ggml-metal-embed.metal
+llama-batched-swift
 out/
 tmp/

+# CI
+!.github/workflows/*.yml

+# Models
 models/*
 models-mnt
+!models/.editorconfig
+!models/ggml-vocab-*.gguf*

-/Pipfile
+# Zig
-/libllama.so
-/llama-*
-llama-batched-swift
-/common/build-info.cpp
-arm_neon.h
-compile_commands.json
-CMakeSettings.json
-__pycache__
-dist
 zig-out/
 zig-cache/

+# Logs
 ppl-*.txt
 qnt-*.txt
 perf-*.txt

+# Examples
 examples/jeopardy/results.txt
+examples/server/*.css.hpp
 examples/server/*.html.hpp
 examples/server/*.js.hpp
 examples/server/*.mjs.hpp
-examples/server/*.css.hpp
+!build_64.sh
+!examples/*.bat
+!examples/*/*.kts
+!examples/*/*/*.kts
+!examples/sycl/*.bat
+!examples/sycl/*.sh

+# Python
+__pycache__
+.venv
+/Pipfile
+dist
 poetry.lock
 poetry.toml
-nppBackup

 # Test binaries
-/tests/test-grammar-parser
+/tests/test-backend-ops
-/tests/test-llama-grammar
 /tests/test-double-float
 /tests/test-grad0
+/tests/test-grammar-parser
+/tests/test-llama-grammar
 /tests/test-opt
 /tests/test-quantize-fns
 /tests/test-quantize-perf
+/tests/test-rope
 /tests/test-sampling
 /tests/test-tokenizer-0
-/tests/test-tokenizer-1-spm
 /tests/test-tokenizer-1-bpe
-/tests/test-rope
+/tests/test-tokenizer-1-spm
-/tests/test-backend-ops

+# Scripts
+!/scripts/install-oneapi.bat

View file

@@ -665,6 +665,7 @@ if (LLAMA_SYCL)
    #todo: AOT
    find_package(IntelSYCL REQUIRED)
+   find_package(MKL REQUIRED)
    message(STATUS "SYCL found")
@@ -679,11 +680,9 @@ if (LLAMA_SYCL)
    endif()
    add_compile_options(-I./) #include DPCT
-   add_compile_options(-I/${SYCL_INCLUDE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
-   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib")
    if (LLAMA_SYCL_TARGET STREQUAL "NVIDIA")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
    endif()
@@ -693,8 +692,10 @@ if (LLAMA_SYCL)
    list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
    if (WIN32)
-       set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl sycl7 OpenCL mkl_sycl_blas_dll.lib mkl_intel_ilp64_dll.lib mkl_sequential_dll.lib mkl_core_dll.lib)
+       set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
    else()
+       add_compile_options(-I/${SYCL_INCLUDE_DIR})
+       set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib")
        if (LLAMA_SYCL_TARGET STREQUAL "INTEL")
            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
        elseif (LLAMA_SYCL_TARGET STREQUAL "NVIDIA")

View file

@@ -11,9 +11,21 @@
            "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
        }
    },
+   {
+       "name": "sycl-base",
+       "hidden": true,
+       "generator": "Ninja",
+       "binaryDir": "${sourceDir}/build-${presetName}",
+       "cacheVariables": {
+           "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
+           "CMAKE_CXX_COMPILER": "icx",
+           "LLAMA_SYCL": "ON",
+           "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
+       }
+   },
    { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
-   { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
+   { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
+   { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
    { "name": "static", "hidden": true, "cacheVariables": { "LLAMA_STATIC": "ON" } },
    {
@@ -35,15 +47,18 @@
    },
    { "name": "arm64-windows-llvm-debug"  , "inherits": [ "base", "arm64-windows-llvm", "debug"  ] },
-   { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "release" ] },
+   { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
-   { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "release", "static" ] },
+   { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
    { "name": "arm64-windows-msvc-debug"  , "inherits": [ "base", "arm64-windows-msvc", "debug"  ] },
-   { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "release" ] },
+   { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
-   { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "release", "static" ] },
+   { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
    { "name": "x64-windows-msvc-debug"  , "inherits": [ "base", "debug"  ] },
-   { "name": "x64-windows-msvc-release", "inherits": [ "base", "release" ] },
+   { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
-   { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "release", "static" ] }
+   { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
+   { "name": "x64-windows-sycl-debug"  , "inherits": [ "sycl-base", "debug"   ] },
+   { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
  ]
 }

View file

@@ -1051,7 +1051,7 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

-tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS)
+tests/test-grammar-integration: tests/test-grammar-integration.cpp json-schema-to-grammar.o ggml.o llama.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@@ -410,15 +410,9 @@ Output (example):
 4. Install build tools

-a. Download & install cmake for Windows: https://cmake.org/download/
+a. Download & install cmake for Windows: https://cmake.org/download/ (CMake can also be installed from Visual Studio Installer)
+b. The new Visual Studio will install Ninja as default. (If not, please install it manually: https://ninja-build.org/)

-b. Download & install mingw-w64 make for Windows provided by w64devkit
-
-- Download the 1.19.0 version of [w64devkit](https://github.com/skeeto/w64devkit/releases/download/v1.19.0/w64devkit-1.19.0.zip).
-- Extract `w64devkit` on your pc.
-- Add the **bin** folder path in the Windows system PATH environment (for e.g. `C:\xxx\w64devkit\bin\`).

 ### II. Build llama.cpp

@@ -428,10 +422,10 @@ On the oneAPI command line window, step into the llama.cpp main directory and ru
 @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force

 # Option 1: Use FP32 (recommended for better performance in most cases)
-cmake -B build -G "MinGW Makefiles" -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
+cmake -B build -G "Ninja" -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release

 # Option 2: Or FP16
-cmake -B build -G "MinGW Makefiles" -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
+cmake -B build -G "Ninja" -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON

 cmake --build build --config Release -j
 ```
@@ -441,9 +435,23 @@ Otherwise, run the `win-build-sycl.bat` wrapper which encapsulates the former in
 .\examples\sycl\win-build-sycl.bat
 ```

+Or, use CMake presets to build:
+
+```sh
+cmake --preset x64-windows-sycl-release
+cmake --build build-x64-windows-sycl-release -j --target llama-cli
+
+cmake -DLLAMA_SYCL_F16=ON --preset x64-windows-sycl-release
+cmake --build build-x64-windows-sycl-release -j --target llama-cli
+
+cmake --preset x64-windows-sycl-debug
+cmake --build build-x64-windows-sycl-debug -j --target llama-cli
+```
+
+Or, you can use Visual Studio to open llama.cpp folder as a CMake project. Choose the sycl CMake presets (`x64-windows-sycl-release` or `x64-windows-sycl-debug`) before you compile the project.
+
 *Notes:*
-- By default, calling `make` will build all target binary files. In case of a minimal experimental setup, the user can build the inference executable only through `make llama-cli`.
+- In case of a minimal experimental setup, the user can build the inference executable only through `cmake --build build --config Release -j --target llama-cli`.

 ### III. Run the inference

View file

@@ -6,7 +6,6 @@
 #include "llama.h"

 #include <algorithm>
-#include <cassert>
 #include <cinttypes>
 #include <cmath>
 #include <codecvt>
@@ -542,6 +541,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
        else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
        else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
+       else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
        else { invalid_param = true; }
        return true;
    }
@@ -1870,6 +1870,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
    options.push_back({ "backend" });
    options.push_back({ "*", "       --rpc SERVERS", "comma separated list of RPC servers" });

    if (llama_supports_mlock()) {
        options.push_back({ "*", "       --mlock", "force system to keep model in RAM rather than swapping or compressing" });
    }
@@ -1988,8 +1989,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
    options.push_back({ "cvector", "       --completions-file FNAME",
                        "completions file (default: '%s')", params.cvector_completions_file.c_str() });
    options.push_back({ "cvector", "       --completions N", "number of lines of completions file to use (default: %d)", params.n_completions });
-   options.push_back({ "cvector", "       --batch-pca N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch });
-   options.push_back({ "cvector", "       --iter-pca N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations });
+   options.push_back({ "cvector", "       --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch });
+   options.push_back({ "cvector", "       --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations });

    printf("usage: %s [options]\n", argv[0]);
@@ -2657,7 +2658,14 @@ static bool llama_download_file(const std::string & url, const std::string & pat
    }

    // Set the output file
-   std::unique_ptr<FILE, decltype(&fclose)> outfile(fopen(path_temporary.c_str(), "wb"), fclose);
+
+   struct FILE_deleter {
+       void operator()(FILE * f) const {
+           fclose(f);
+       }
+   };
+
+   std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
    if (!outfile) {
        fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str());
        return false;
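The deleter change in that last hunk is easy to miss, so here is a standalone sketch of what it buys (the file names are placeholders, and the size claim holds on the major standard libraries): a function-pointer deleter such as `decltype(&fclose)` must be passed at construction and is stored next to the managed pointer, while an empty functor type adds no per-object state and cannot be left null.

```cpp
#include <cstdio>
#include <memory>

// Empty functor type: the deleter is part of the type, not per-object state.
struct FILE_deleter {
    void operator()(FILE * f) const { fclose(f); }
};

int main() {
    // Function-pointer deleter: fclose must be passed explicitly and is
    // stored inside the unique_ptr, so the object holds two pointers.
    std::unique_ptr<FILE, decltype(&fclose)> p1(fopen("a.tmp", "wb"), fclose);

    // Functor deleter: nothing extra to pass or store; libstdc++, libc++ and
    // MSVC all optimize the empty deleter away entirely.
    std::unique_ptr<FILE, FILE_deleter> p2(fopen("b.tmp", "wb"));

    static_assert(sizeof(p2) == sizeof(FILE *), "empty deleter adds no state");
    return (p1 && p2) ? 0 : 1;
}
```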

View file

@@ -214,7 +214,7 @@ src_func = f"""
 """

 convert_py_pth = pathlib.Path("convert-hf-to-gguf.py")
-convert_py = convert_py_pth.read_text()
+convert_py = convert_py_pth.read_text(encoding="utf-8")
 convert_py = re.sub(
     r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)",
     lambda m: m.group(1) + src_func + m.group(3),
@@ -222,7 +222,7 @@ convert_py = re.sub(
     flags=re.DOTALL | re.MULTILINE,
 )

-convert_py_pth.write_text(convert_py)
+convert_py_pth.write_text(convert_py, encoding="utf-8")

 logger.info("+++ convert-hf-to-gguf.py was updated")

View file

@@ -967,7 +967,11 @@ class XverseModel(Model):
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(dir_model)
         vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
-        assert max(tokenizer.vocab.values()) < vocab_size
+        # Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size,
+        # because vocab_size is the count of items, and indexes start at 0.
+        max_vocab_index = max(tokenizer.get_vocab().values())
+        if max_vocab_index >= vocab_size:
+            raise ValueError("Vocabulary size exceeds expected maximum size.")

         reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
         added_vocab = tokenizer.get_added_vocab()

View file

@@ -17,7 +17,7 @@ Related PRs:
 ./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99

 # With advanced options
-./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --batch-pca 100
+./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100

 # To see help message
 ./cvector-generator -h

View file

@@ -40,7 +40,7 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) {
     printf("\nexample usage:\n");
     printf("\n    CPU only: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf\n", argv[0]);
     printf("\n    with GPU: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99\n", argv[0]);
-    printf("\n    advanced: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --batch-pca 100\n", argv[0]);
+    printf("\n    advanced: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100\n", argv[0]);
     printf("\n");
 }
@@ -377,8 +377,8 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) {
     // create templated prompts
     std::vector<std::string> completions = ctrlvec_load_prompt_file(params.cvector_completions_file, false);
     auto format_template = [](std::string persona, std::string suffix) {
-        // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST]"
-        return persona + " " + suffix;
+        // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST] "
+        return persona + suffix;
     };
     for (size_t i = 0; i < positive_prompts.size(); ++i) {
         for (int j = 0; j < std::min((int) completions.size(), params.n_completions); ++j) {

View file

@@ -17,9 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }

-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
@@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");

         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +92,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);

+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
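Setting the logits flag to `true` for every position above works together with the new pooling requirement: the pooled sequence embedding is computed inside the library from per-token outputs. A minimal standalone sketch of what MEAN and LAST pooling conceptually reduce (not the library code; `pool_sequence` and `pooling_kind` are made up for illustration):

```cpp
#include <cstddef>
#include <vector>

enum class pooling_kind { mean, last };

// Reduce per-token embeddings (n_tokens x n_embd) to one sequence embedding.
// MEAN needs every token's embedding and LAST needs the final one, which is
// why the batch is now built with output enabled on all positions.
static std::vector<float> pool_sequence(
        const std::vector<std::vector<float>> & tok_embd, pooling_kind kind) {
    if (kind == pooling_kind::last) {
        return tok_embd.back();
    }
    const size_t n_embd = tok_embd.front().size();
    std::vector<float> out(n_embd, 0.0f);
    for (const auto & e : tok_embd) {
        for (size_t i = 0; i < n_embd; ++i) {
            out[i] += e[i] / (float) tok_embd.size();
        }
    }
    return out;
}

int main() {
    const std::vector<std::vector<float>> e = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    return pool_sequence(e, pooling_kind::mean)[0] == 2.0f ? 0 : 1;
}
```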

View file

@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
        // clear previous kv_cache values (irrelevant for embeddings)
        llama_kv_cache_clear(ctx);
+       llama_set_embeddings(ctx, true);
        llama_set_causal_attn(ctx, false);

        // run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
    llama_token eos_token = llama_token_eos(mdl);

    llama_kv_cache_clear(ctx);
+   llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);
+
    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {
    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

-   // create new context - set to embedding mode
-   cparams.embeddings = true;
+   // create generation context
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

    // ### Embedding/Representation ###
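With `llama_set_embeddings()` one context can now serve both phases, which is why `cparams.embeddings = true` is dropped at context creation. A sketch of the resulting toggle pattern, assuming only the `llama.h` calls shown in the diff; `encode_prompt`/`generate_reply` in the comments are stand-ins for the example's own helpers:

```cpp
#include "llama.h"

// Flip a single context between GritLM's two modes at runtime.
static void embed_then_generate(llama_context * ctx) {
    // embedding pass: pooled outputs, bidirectional attention
    llama_set_embeddings(ctx, true);
    llama_set_causal_attn(ctx, false);
    // ... encode_prompt(ctx, ...) would run llama_decode and read embeddings ...

    // generation pass: logits for sampling, causal attention
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);
    // ... generate_reply(ctx, ...) would run the usual sampling loop ...
}
```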

View file

@@ -131,23 +131,30 @@ class LlamaState: ObservableObject {
        messageLog += "\(text)"

+       Task.detached {
            while await llamaContext.n_cur < llamaContext.n_len {
                let result = await llamaContext.completion_loop()
-               messageLog += "\(result)"
+               await MainActor.run {
+                   self.messageLog += "\(result)"
+               }
            }

            let t_end = DispatchTime.now().uptimeNanoseconds
-           let t_generation = Double(t_end - t_heat_end) / NS_PER_S
+           let t_generation = Double(t_end - t_heat_end) / self.NS_PER_S
            let tokens_per_second = Double(await llamaContext.n_len) / t_generation

            await llamaContext.clear()
-           messageLog += """

+           await MainActor.run {
+               self.messageLog += """
                \n
                Done
                Heat up took \(t_heat)s
                Generated \(tokens_per_second) t/s\n
                """
+           }
+       }
    }

    func bench() async {
        guard let llamaContext else {

View file

@@ -16,37 +16,37 @@ struct quant_option {
 };

 static const std::vector<struct quant_option> QUANT_OPTIONS = {
-    { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 3.56G, +0.2166 ppl @ LLaMA-v1-7B", },
+    { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
-    { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
+    { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  },
-    { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
+    { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  },
-    { "Q5_1",     LLAMA_FTYPE_MOSTLY_Q5_1,     " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
+    { "Q5_1",     LLAMA_FTYPE_MOSTLY_Q5_1,     " 5.65G, +0.1062 ppl @ Llama-3-8B",  },
     { "IQ2_XXS",  LLAMA_FTYPE_MOSTLY_IQ2_XXS,  " 2.06 bpw quantization",            },
     { "IQ2_XS",   LLAMA_FTYPE_MOSTLY_IQ2_XS,   " 2.31 bpw quantization",            },
     { "IQ2_S",    LLAMA_FTYPE_MOSTLY_IQ2_S,    " 2.5  bpw quantization",            },
     { "IQ2_M",    LLAMA_FTYPE_MOSTLY_IQ2_M,    " 2.7  bpw quantization",            },
     { "IQ1_S",    LLAMA_FTYPE_MOSTLY_IQ1_S,    " 1.56 bpw quantization",            },
     { "IQ1_M",    LLAMA_FTYPE_MOSTLY_IQ1_M,    " 1.75 bpw quantization",            },
-    { "Q2_K",     LLAMA_FTYPE_MOSTLY_Q2_K,     " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
+    { "Q2_K",     LLAMA_FTYPE_MOSTLY_Q2_K,     " 2.96G, +3.5199 ppl @ Llama-3-8B",  },
-    { "Q2_K_S",   LLAMA_FTYPE_MOSTLY_Q2_K_S,   " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
+    { "Q2_K_S",   LLAMA_FTYPE_MOSTLY_Q2_K_S,   " 2.96G, +3.1836 ppl @ Llama-3-8B",  },
     { "IQ3_XXS",  LLAMA_FTYPE_MOSTLY_IQ3_XXS,  " 3.06 bpw quantization",            },
     { "IQ3_S",    LLAMA_FTYPE_MOSTLY_IQ3_S,    " 3.44 bpw quantization",            },
     { "IQ3_M",    LLAMA_FTYPE_MOSTLY_IQ3_M,    " 3.66 bpw quantization mix",        },
     { "Q3_K",     LLAMA_FTYPE_MOSTLY_Q3_K_M,   "alias for Q3_K_M"                   },
-    { "IQ3_XS",   LLAMA_FTYPE_MOSTLY_IQ3_XS,   " 3.3 bpw quantization" ,            },
+    { "IQ3_XS",   LLAMA_FTYPE_MOSTLY_IQ3_XS,   " 3.3 bpw quantization",             },
-    { "Q3_K_S",   LLAMA_FTYPE_MOSTLY_Q3_K_S,   " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
+    { "Q3_K_S",   LLAMA_FTYPE_MOSTLY_Q3_K_S,   " 3.41G, +1.6321 ppl @ Llama-3-8B",  },
-    { "Q3_K_M",   LLAMA_FTYPE_MOSTLY_Q3_K_M,   " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
+    { "Q3_K_M",   LLAMA_FTYPE_MOSTLY_Q3_K_M,   " 3.74G, +0.6569 ppl @ Llama-3-8B",  },
-    { "Q3_K_L",   LLAMA_FTYPE_MOSTLY_Q3_K_L,   " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
+    { "Q3_K_L",   LLAMA_FTYPE_MOSTLY_Q3_K_L,   " 4.03G, +0.5562 ppl @ Llama-3-8B",  },
     { "IQ4_NL",   LLAMA_FTYPE_MOSTLY_IQ4_NL,   " 4.50 bpw non-linear quantization", },
     { "IQ4_XS",   LLAMA_FTYPE_MOSTLY_IQ4_XS,   " 4.25 bpw non-linear quantization", },
     { "Q4_K",     LLAMA_FTYPE_MOSTLY_Q4_K_M,   "alias for Q4_K_M",                  },
-    { "Q4_K_S",   LLAMA_FTYPE_MOSTLY_Q4_K_S,   " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
+    { "Q4_K_S",   LLAMA_FTYPE_MOSTLY_Q4_K_S,   " 4.37G, +0.2689 ppl @ Llama-3-8B",  },
-    { "Q4_K_M",   LLAMA_FTYPE_MOSTLY_Q4_K_M,   " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
+    { "Q4_K_M",   LLAMA_FTYPE_MOSTLY_Q4_K_M,   " 4.58G, +0.1754 ppl @ Llama-3-8B",  },
     { "Q5_K",     LLAMA_FTYPE_MOSTLY_Q5_K_M,   "alias for Q5_K_M",                  },
-    { "Q5_K_S",   LLAMA_FTYPE_MOSTLY_Q5_K_S,   " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
+    { "Q5_K_S",   LLAMA_FTYPE_MOSTLY_Q5_K_S,   " 5.21G, +0.1049 ppl @ Llama-3-8B",  },
-    { "Q5_K_M",   LLAMA_FTYPE_MOSTLY_Q5_K_M,   " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
+    { "Q5_K_M",   LLAMA_FTYPE_MOSTLY_Q5_K_M,   " 5.33G, +0.0569 ppl @ Llama-3-8B",  },
-    { "Q6_K",     LLAMA_FTYPE_MOSTLY_Q6_K,     " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
+    { "Q6_K",     LLAMA_FTYPE_MOSTLY_Q6_K,     " 6.14G, +0.0217 ppl @ Llama-3-8B",  },
-    { "Q8_0",     LLAMA_FTYPE_MOSTLY_Q8_0,     " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
+    { "Q8_0",     LLAMA_FTYPE_MOSTLY_Q8_0,     " 7.96G, +0.0026 ppl @ Llama-3-8B",  },
-    { "F16",      LLAMA_FTYPE_MOSTLY_F16,      "14.00G, -0.0020 ppl @ Mistral-7B",  },
+    { "F16",      LLAMA_FTYPE_MOSTLY_F16,      "14.00G, +0.0020 ppl @ Mistral-7B",  },
     { "BF16",     LLAMA_FTYPE_MOSTLY_BF16,     "14.00G, -0.0050 ppl @ Mistral-7B",  },
     { "F32",      LLAMA_FTYPE_ALL_F32,         "26.00G              @ 7B",          },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.

View file

@@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }

-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
@@ -160,6 +161,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);

+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);

View file

@@ -634,12 +634,12 @@ return html`
            <div>
              <div class="grammar">
                <label for="template"></label>
-               <textarea id="grammar" name="grammar" placeholder="Use GBNF or JSON-Scheme + Converter" value="${params.value.grammar}" rows=4 oninput=${updateParams}/>
+               <textarea id="grammar" name="grammar" placeholder="Use GBNF or JSON Schema + Converter" value="${params.value.grammar}" rows=4 oninput=${updateParams}/>
              </div>
              <div class="grammar-columns">
                <div class="json-schema-controls">
                  <input type="text" name="prop-order" placeholder="Order: prop1,prop2,prop3" oninput=${updateGrammarJsonSchemaPropOrder} />
-                 <button type="button" class="button-grammar" onclick=${convertJSONSchemaGrammar}>Convert JSON-Scheme</button>
+                 <button type="button" class="button-grammar" onclick=${convertJSONSchemaGrammar}>Convert JSON Schema</button>
                </div>
              </div>
            </div>

View file

@@ -1594,7 +1594,7 @@ struct server_context {
        } else {
            std::string prompt;
            if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
-               json_value(task.data, "prompt", std::string());
+               prompt = json_value(task.data, "prompt", std::string());
            }

            slot = get_available_slot(prompt);

View file

@@ -13,16 +13,16 @@ if %errorlevel% neq 0 goto ERROR

 :: for FP16
 :: faster for long-prompt inference
-:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
+:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON

 :: for FP32
-cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
+cmake -G "Ninja" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
 if %errorlevel% neq 0 goto ERROR

 :: build example/main only
 :: make main

 :: build all binary
-make -j
+cmake --build . -j
 if %errorlevel% neq 0 goto ERROR

 cd ..

View file

@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
        }

        const int cc = ggml_cuda_info().devices[id].cc;
-       row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+       row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
    }
    return row_rounding;
 }

View file

@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
 }

 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+static int get_mmq_y_host(const int cc) {
+    return cc >= CC_VOLTA ? 128 : 64;
 }

 //////////////////////

View file

@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
-           mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q4_1:
-           mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_0:
-           mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_1:
-           mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
            break;
        case GGML_TYPE_Q8_0:
-           mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
            break;
        case GGML_TYPE_Q2_K:
-           mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q3_K:
-           mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q4_K:
-           mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q5_K:
-           mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
            break;
        case GGML_TYPE_Q6_K:
-           mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+           mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
            break;
        default:
            GGML_ASSERT(false);

View file

@@ -8,6 +8,7 @@
 #include <cstdint>

 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8

 typedef void (*load_tiles_mmq_t)(
     const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
@@ -15,7 +16,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

 struct block_q8_1_mmq {
     half2 ds[4];
@@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {
 // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row

+static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
     return 64;
-}
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}

 #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
@@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }

 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+        const int j = j0 + threadIdx.y;

-        if (j >= ne1) {
+        if (j > j_max) {
             return;
         }

 #pragma unroll
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+            const int i = i0 + threadIdx.x;

-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }

-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }

 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

     typedef mma_int_C_I16J8 mma_C;

     const int i0 = threadIdx.y*mma_C::I;
@@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+            const int j = j0 + mma_C::get_j(l);

-            if (j >= ne1) {
+            if (j > j_max) {
                 continue;
             }

-            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+            const int i = i0 + mma_C::get_i(l);

-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }

-            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+            dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
         }
     }
 }
@@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
     return false;
 }

-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#else
-#if __CUDA_ARCH__ >= CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
-#else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
-    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
-
-    // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
-        NO_DEVICE_CODE;
-        return;
-    }
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {

     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int qr = ggml_cuda_type_traits<type>::qr;
     constexpr int qi = ggml_cuda_type_traits<type>::qi;
-    constexpr int mmq_y = get_mmq_y_device(mmq_x);
+    constexpr int mmq_y = get_mmq_y_device();
     constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
@@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
     int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
     int * tile_y    = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]

-    const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
-
-    const int & ne1 = ne11;
-
-    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
-
-    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+    constexpr int blocks_per_warp = WARP_SIZE / qi;

     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

-    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;

-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
+        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);

 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
         }
     }

-    write_back(sum, dst, ne0, ne1);
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi = ggml_cuda_type_traits<type>::qi;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+    const int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t       kbc      = GGML_PAD((int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+         it, jt, kb0_start, kb0_stop);
+}
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int     mmq_y           = get_mmq_y_device();
+    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
+    constexpr int     qi              = ggml_cuda_type_traits<type>::qi;
+    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = (blockIdx.y*nty + blockIdx.x)     * block_num_mmq / (gridDim.y*gridDim.x);
+    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        const int64_t kbc      = GGML_PAD((int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
 }
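The slice bookkeeping in the two kernels above is dense, so here is a host-side C++ walkthrough of the same stream-k arithmetic with made-up sizes. It mirrors the `kbc`/`kb0` logic only; it is a sketch, not the kernel:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// like GGML_PAD: round x up to a multiple of n
static int64_t pad_to(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const int64_t blocks_per_ne00 = 16; // k blocks per output tile (ne00/qk)
    const int     ntx = 2, nty = 3;     // output tiles in the j/i directions
    const int     grid = 4;             // CUDA blocks launched; typically the SM count
    const int64_t blocks_per_warp = 4;  // slice boundaries are padded to this

    const int64_t total = blocks_per_ne00*ntx*nty; // flattened (jt, it, k) space
    for (int b = 0; b < grid; ++b) {
        int64_t       kbc      = pad_to(((int64_t) b    )*total/grid, blocks_per_warp);
        const int64_t kbc_stop = pad_to(((int64_t) b + 1)*total/grid, blocks_per_warp);

        int64_t kb0_start = kbc % blocks_per_ne00;
        int64_t kb0_stop  = std::min(blocks_per_ne00, kb0_start + kbc_stop - kbc);

        // Tiles whose k range this slice finishes are written straight to dst:
        while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
            const int jt = (int) ( kbc /    (blocks_per_ne00*nty));
            const int it = (int) ((kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00);
            printf("block %d: tile(it=%d, jt=%d) k[%2lld,%2lld) -> dst\n",
                   b, it, jt, (long long) kb0_start, (long long) kb0_stop);
            kbc += blocks_per_ne00;
            kbc -= kbc % blocks_per_ne00; // jump to the next tile boundary
            kb0_start = 0;
            kb0_stop  = std::min(blocks_per_ne00, kbc_stop - kbc);
        }
        // A trailing partial tile goes to the fixup buffer instead; the
        // mul_mat_q_stream_k_fixup kernel later adds it into dst.
        if (kbc < kbc_stop) {
            const int jt = (int) ( kbc /    (blocks_per_ne00*nty));
            const int it = (int) ((kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00);
            printf("block %d: tile(it=%d, jt=%d) k[%2lld,%2lld) -> fixup\n",
                   b, it, jt, (long long) kb0_start, (long long) kb0_stop);
        }
    }
    return 0;
}
```

Running it shows that only slices ending mid-tile produce fixup entries, which is exactly what `mul_mat_q_stream_k_fixup` accumulates into `dst`.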
 struct mmq_args {
@@ -1987,124 +2139,151 @@ struct mmq_args {
     int64_t ne0;
 };

-constexpr int mmq_get_nwarps(int mmq_x) {
-    return mmq_x >= 32 ? 8 : 4;
-}
-
 static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
     const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int nwarps = mmq_get_nwarps(mmq_x);
     const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }

-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);

-    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);

     const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);

 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
         if (args.ne01 % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-                (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
         } else {
-            const bool need_check = true;
-            mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-                (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
         }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
+    if (args.ne01 % mmq_y == 0) {
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    } else {
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    }
 }
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
int mmq_x_best = 0;
int nparts_best = INT_MAX;
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
const int nwaves_xy_tiling = ntiles_x*block_num_y;
const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
mmq_x_best = mmq_x;
nparts_best = nparts;
}
}
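// Tile-width search: mmq_x grows in steps of 8 while the resulting part
// count still improves, and a candidate only qualifies if its shared-memory
// footprint mmq_get_shmem(type, mmq_x, mmq_y) fits within smpbo. With
// stream-k the cost metric is the number of x tiles; otherwise it is the
// total block count of the xy tiling. The loop stops early once one part
// suffices.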
switch (mmq_x_best) {
case 8:
launch_mul_mat_q<type, 8>(ctx, args, stream);
break;
case 16:
launch_mul_mat_q<type, 16>(ctx, args, stream);
break;
case 24:
launch_mul_mat_q<type, 24>(ctx, args, stream);
break;
case 32:
launch_mul_mat_q<type, 32>(ctx, args, stream);
break;
case 40:
launch_mul_mat_q<type, 40>(ctx, args, stream);
break;
case 48:
launch_mul_mat_q<type, 48>(ctx, args, stream);
break;
case 56:
launch_mul_mat_q<type, 56>(ctx, args, stream);
break;
case 64:
launch_mul_mat_q<type, 64>(ctx, args, stream);
break;
case 72:
launch_mul_mat_q<type, 72>(ctx, args, stream);
break;
case 80:
launch_mul_mat_q<type, 80>(ctx, args, stream);
break;
case 88:
launch_mul_mat_q<type, 88>(ctx, args, stream);
break;
case 96:
launch_mul_mat_q<type, 96>(ctx, args, stream);
break;
case 104:
launch_mul_mat_q<type, 104>(ctx, args, stream);
break;
case 112:
launch_mul_mat_q<type, 112>(ctx, args, stream);
break;
case 120:
launch_mul_mat_q<type, 120>(ctx, args, stream);
break;
case 128:
launch_mul_mat_q<type, 128>(ctx, args, stream);
break;
default:
fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
}
#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);

View file

@ -735,6 +735,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
}
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
return false;
}
}
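// GGML_TYPE_BF16 sources are rejected up front: the Metal kernels below do
// not provide bfloat16 code paths, so such ops are left to other backends.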
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {

View file

@ -8814,7 +8814,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
#endif
}
#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
@ -8947,6 +8947,61 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
*s = 0.125f * hsum_float_8(accumf);
#elif defined(__AVX__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const uint16_t ls1 = aux32[1] >> 28;
const uint16_t ls2 = aux32[3] >> 28;
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
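// keven_signs_q2xs above encodes the iq2_xxs sign scheme: 7 explicit sign
// bits per group of 8 weights, with the 8th sign chosen so that the count of
// negative signs is even. A scalar reference of the table's expansion
// (hypothetical helper, for illustration):
//
//   static void expand_even_signs(uint32_t bits7, int8_t out[8]) {
//       int parity = 0;
//       for (int j = 0; j < 7; ++j) {
//           int s = (bits7 >> j) & 1;
//           out[j] = s ? -1 : 1;
//           parity ^= s;
//       }
//       out[7] = parity ? -1 : 1; // completes the even minus-count
//   }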
#elif defined(__POWER9_VECTOR__)
const vector int v0 = vec_splats((int32_t)0);
vector float vsumf0 = vec_splats(0.0f);
@ -9290,6 +9345,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
}
*s = 0.125f * hsum_float_8(accumf);
#elif defined(__AVX__)
const __m128i mone = _mm_set1_epi8(1);
static const char block_sign_shuffle_mask_1[32] = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
};
static const char block_sign_shuffle_mask_2[32] = {
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
};
static const uint8_t bit_selector_mask_bytes[32] = {
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
static const uint8_t k_bit_helper[32] = {
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
};
const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
const __m128i m511 = _mm_set1_epi16(511);
const __m128i m4 = _mm_set1_epi8(0xf);
const __m128i m1 = _mm_set1_epi8(1);
uint64_t aux64;
// somewhat hacky, but gives a significant boost in performance
__m256i aux_gindex;
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memcpy(&aux64, x[i].scales, 8);
__m128i stmp = _mm_set1_epi64x(aux64);
stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
// AVX2 full_signs_1 is full_sign_bits_0 here
// AVX2 full_signs_2 is full_sign_bits_1 here
__m128i signs_0, signs_1;
signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
__m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
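// k_bit_helper is a 16-entry odd-parity lookup driven through
// _mm_shuffle_epi8: each q2 word stores a 9-bit grid index plus 7 sign bits,
// and the 8th sign bit is reconstructed as the parity of the stored ones.
// Scalar equivalent (hypothetical helper, for illustration):
//
//   static uint8_t full_signs(uint16_t q2_word) {
//       uint8_t s = q2_word >> 9;                      // 7 stored sign bits
//       return s | ((__builtin_popcount(s) & 1) << 7); // parity-derived bit
//   }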
#elif defined(__loongarch_asx)
const __m256i mone = __lasx_xvreplgr2vr_b(1);
@ -9693,6 +9907,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
*s = 0.125f * hsum_float_8(accumf);
#elif defined(__AVX__)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
};
static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
const __m128i m4 = _mm_set1_epi8(0xf);
const __m128i m1 = _mm_set1_epi8(1);
const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
uint64_t aux64;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict qs = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
const int8_t * restrict q8 = y[i].qs;
memcpy(&aux64, x[i].scales, 8);
const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
qs += 8;
__m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
__m128i aux128_1 = aux128_0;
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
aux128_1 = aux128_0;
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
signs += 4;
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
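// Sign application here uses the branch-free negate-select idiom: with each
// byte of s2_* equal to 0xFF or 0x00, (q8 ^ s) - s yields -q8 where the sign
// bit was set and q8 unchanged elsewhere (x ^ 0xFF == ~x, and ~x + 1 == -x).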
#elif defined(__POWER9_VECTOR__)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@ -10019,6 +10325,63 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
*s = 0.25f * hsum_float_8(accumf);
#elif defined(__AVX__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[2];
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict gas = x[i].qs + QK_K/4;
const int8_t * restrict q8 = y[i].qs;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
q3 += 8;
const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
q3 += 8;
memcpy(aux32, gas, 8); gas += 8;
const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const uint16_t ls1 = aux32[0] >> 28;
const uint16_t ls2 = aux32[1] >> 28;
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = 0.25f * hsum_float_8(accumf);
#elif defined(__POWER9_VECTOR__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
@ -10370,6 +10733,112 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
*s = hsum_float_8(accumf);
#elif defined(__AVX__)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
};
static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
const __m128i idx_mask = _mm_set1_epi32(256);
typedef union {
__m128i vec[4];
uint32_t index[16];
} index_t;
index_t idx;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict qs = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
const int8_t * restrict q8 = y[i].qs;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
idx.vec[1] = idx.vec[0];
idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
idx.vec[3] = idx.vec[2];
idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
__m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
__m128i aux128_1 = aux128_0;
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
aux128_1 = aux128_0;
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
signs += 4;
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = hsum_float_8(accumf);
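// SSE4 has no per-lane variable shifts, so the qh high bits are extracted by
// multiplication instead: broadcasting qh and multiplying each 32-bit lane by
// a distinct power of two (idx_mul_0/idx_mul_1) slides the wanted bit into
// position 8, where the & idx_mask (256) isolates it before it is OR-ed onto
// the low 8 index bits taken from qs.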
#elif defined(__POWER9_VECTOR__)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@ -10607,6 +11076,14 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
}
#if defined(__AVX__)
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
const __m128i ax = _mm_sign_epi8(x, x);
const __m128i sy = _mm_sign_epi8(y, x);
return _mm_maddubs_epi16(ax, sy);
}
#endif
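// mul_add_epi8_sse works around _mm_maddubs_epi16 requiring an unsigned
// first operand: _mm_sign_epi8(x, x) yields |x| and _mm_sign_epi8(y, x)
// transfers x's sign onto y, so ax*sy == x*y per byte (and 0 when x == 0).
// Per 16-bit output lane the result is x[2k]*y[2k] + x[2k+1]*y[2k+1].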
#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
@ -10724,6 +11201,54 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
#elif defined __AVX__
__m256 accum = _mm256_setzero_ps();
float accum1 = 0;
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint16_t * qh = x[i].qh;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
int sumi1 = 0;
for (int ib = 0; ib < QK_K/32; ib += 2) {
const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
qs += 8;
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
}
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
accum1 += d * sumi1;
}
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
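// iq1_s applies a per-block offset of +/-IQ1S_DELTA to every weight; since
// the group sums of q8 are precomputed in y[i].bsums, the offset's
// contribution reduces to bsums * sign * ls and is accumulated separately in
// sumi1/accum1 rather than being folded into the SIMD dot products.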
#elif defined(__POWER9_VECTOR__)
const vector unsigned char v0 = vec_splats((unsigned char)0x0);
const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
@ -11062,6 +11587,92 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
*s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
#elif defined __AVX__
const __m128i mask = _mm_set1_epi16(0x7);
const __m128i mone = _mm_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib = 0; ib < QK_K/32; ib += 2) {
const __m128i q1b_1_0 = _mm_set_epi64x(
iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
const __m128i q1b_1_1 = _mm_set_epi64x(
iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
const __m128i q1b_2_0 = _mm_set_epi64x(
iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
const __m128i q1b_2_1 = _mm_set_epi64x(
iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
__m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
__m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
__m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
__m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
qs += 8; qh += 4;
}
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
}
*s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
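// iq1_m tracks the +/-delta per half-block via qh bits 0x08 and 0x80:
// delta1_*/delta2_* are vectors of all +1 or all -1 bytes, so
// mul_add_epi8_sse with q8 yields the signed q8 sums that accum2 later
// scales by IQ1M_DELTA.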
#else
int sum1[2], sum2[2], delta[4];
@ -11192,6 +11803,44 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
*s = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m128i mone = _mm_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (int ib = 0; ib < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[0].qs + 1);
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[1].qs);
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[1].qs + 1);
const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
y += 2;
x += 2;
}
*s = hsum_float_8(_mm256_add_ps(accum1, accum2));
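// The iq4_nl codebook is non-linear, so dequantization is a 16-entry table
// lookup: _mm_shuffle_epi8(values128, nibbles) maps each 4-bit code through
// kvalues_iq4nl in a single instruction, once for the low and once for the
// high nibbles.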
#elif defined(__POWER9_VECTOR__)
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector signed int v0 = vec_splats((int32_t)0);
@ -11382,6 +12031,54 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
*s = hsum_float_8(accum);
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for (int ibl = 0; ibl < nb; ++ibl) {
const uint8_t * qs = x[ibl].qs;
const int8_t * q8 = y[ibl].qs;
uint16_t sh = x[ibl].scales_h;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib = 0; ib < QK_K/32; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
sh >>= 4;
const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
}
__m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
__m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
_mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
}
*s = hsum_float_8(accum);
#elif defined(__POWER9_VECTOR__)
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector int v0 = vec_splats((int32_t)0);

View file

@ -4911,7 +4911,7 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
GGML_TENSOR_BINARY_OP_LOCALS01;
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();

View file

@ -589,94 +589,75 @@ namespace dpct
} }
/// dpct device extension /// dpct device extension
class device_ext : public sycl::device class device_ext : public sycl::device {
{
typedef std::mutex mutex_type; typedef std::mutex mutex_type;
public: public:
device_ext() : sycl::device(), _ctx(*this) {} device_ext() : sycl::device() {}
~device_ext() ~device_ext() {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
clear_queues(); clear_queues();
} }
device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) device_ext(const sycl::device &base) : sycl::device(base) {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
init_queues(); init_queues();
} }
int is_native_atomic_supported() { return 0; } int is_native_atomic_supported() { return 0; }
int get_major_version() const int get_major_version() const { return dpct::get_major_version(*this); }
{
return dpct::get_major_version(*this);
}
int get_minor_version() const int get_minor_version() const { return dpct::get_minor_version(*this); }
{
return dpct::get_minor_version(*this);
}
int get_max_compute_units() const int get_max_compute_units() const {
{
return get_device_info().get_max_compute_units(); return get_device_info().get_max_compute_units();
} }
/// Return the maximum clock frequency of this device in KHz. /// Return the maximum clock frequency of this device in KHz.
int get_max_clock_frequency() const int get_max_clock_frequency() const {
{
return get_device_info().get_max_clock_frequency(); return get_device_info().get_max_clock_frequency();
} }
int get_integrated() const { return get_device_info().get_integrated(); } int get_integrated() const { return get_device_info().get_integrated(); }
int get_max_sub_group_size() const int get_max_sub_group_size() const {
{
return get_device_info().get_max_sub_group_size(); return get_device_info().get_max_sub_group_size();
} }
int get_max_register_size_per_work_group() const int get_max_register_size_per_work_group() const {
{
return get_device_info().get_max_register_size_per_work_group(); return get_device_info().get_max_register_size_per_work_group();
} }
int get_max_work_group_size() const int get_max_work_group_size() const {
{
return get_device_info().get_max_work_group_size(); return get_device_info().get_max_work_group_size();
} }
int get_mem_base_addr_align() const int get_mem_base_addr_align() const {
{
return get_info<sycl::info::device::mem_base_addr_align>(); return get_info<sycl::info::device::mem_base_addr_align>();
} }
size_t get_global_mem_size() const size_t get_global_mem_size() const {
{
return get_device_info().get_global_mem_size(); return get_device_info().get_global_mem_size();
} }
size_t get_max_mem_alloc_size() const size_t get_max_mem_alloc_size() const {
{
return get_device_info().get_max_mem_alloc_size(); return get_device_info().get_max_mem_alloc_size();
} }
/// Get the number of bytes of free and total memory on the SYCL device. /// Get the number of bytes of free and total memory on the SYCL device.
/// \param [out] free_memory The number of bytes of free memory on the SYCL device. /// \param [out] free_memory The number of bytes of free memory on the
/// \param [out] total_memory The number of bytes of total memory on the SYCL device. /// SYCL device. \param [out] total_memory The number of bytes of total
void get_memory_info(size_t &free_memory, size_t &total_memory) /// memory on the SYCL device.
{ void get_memory_info(size_t &free_memory, size_t &total_memory) {
total_memory = get_device_info().get_global_mem_size(); total_memory = get_device_info().get_global_mem_size();
const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " const char *warning_info =
"get_memory_info: [warning] ext_intel_free_memory is not "
"supported (export/set ZES_ENABLE_SYSMAN=1 to support), " "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
"use total memory as free memory"; "use total memory as free memory";
#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
if (!has(sycl::aspect::ext_intel_free_memory)) if (!has(sycl::aspect::ext_intel_free_memory)) {
{
std::cerr << warning_info << std::endl; std::cerr << warning_info << std::endl;
free_memory = total_memory; free_memory = total_memory;
} } else {
else
{
free_memory = get_info<sycl::ext::intel::info::device::free_memory>(); free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
} }
#else #else
@ -690,164 +671,139 @@ namespace dpct
#endif #endif
} }
void get_device_info(device_info &out) const void get_device_info(device_info &out) const {
{
dpct::get_device_info(out, *this); dpct::get_device_info(out, *this);
} }
device_info get_device_info() const device_info get_device_info() const {
{
device_info prop; device_info prop;
dpct::get_device_info(prop, *this); dpct::get_device_info(prop, *this);
return prop; return prop;
} }
void reset() void reset() {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
clear_queues(); clear_queues();
init_queues(); init_queues();
} }
sycl::queue &in_order_queue() { return *_q_in_order; } sycl::queue &in_order_queue() { return _q_in_order; }
sycl::queue &out_of_order_queue() { return *_q_out_of_order; } sycl::queue &out_of_order_queue() { return _q_out_of_order; }
sycl::queue &default_queue() sycl::queue &default_queue() { return in_order_queue(); }
{
return in_order_queue();
}
void queues_wait_and_throw() void queues_wait_and_throw() {
{
std::unique_lock<mutex_type> lock(m_mutex); std::unique_lock<mutex_type> lock(m_mutex);
std::vector<std::shared_ptr<sycl::queue>> current_queues(
_queues);
lock.unlock(); lock.unlock();
for (const auto &q : current_queues) for (auto &q : _queues) {
{ q.wait_and_throw();
q->wait_and_throw();
} }
// Guard the destruct of current_queues to make sure the ref count is safe. // Guard the destruct of current_queues to make sure the ref count is
// safe.
lock.lock(); lock.lock();
} }
sycl::queue *create_queue(bool enable_exception_handler = false) sycl::queue create_queue(bool enable_exception_handler = false) {
{
return create_in_order_queue(enable_exception_handler); return create_in_order_queue(enable_exception_handler);
} }
sycl::queue *create_queue(sycl::context context, sycl::device device, sycl::queue create_queue(sycl::device device,
bool enable_exception_handler = false) { bool enable_exception_handler = false) {
return create_in_order_queue(context, device, enable_exception_handler); return create_in_order_queue(device, enable_exception_handler);
} }
sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(enable_exception_handler, return create_queue_impl(enable_exception_handler,
sycl::property::queue::in_order()); sycl::property::queue::in_order());
} }
sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, sycl::queue create_in_order_queue(sycl::device device,
bool enable_exception_handler = false) { bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(context, device, enable_exception_handler, return create_queue_impl(device, enable_exception_handler,
sycl::property::queue::in_order()); sycl::property::queue::in_order());
} }
sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { sycl::queue create_out_of_order_queue(
bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(enable_exception_handler); return create_queue_impl(enable_exception_handler);
} }
void destroy_queue(sycl::queue *&queue) void destroy_queue(sycl::queue queue) {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
_queues.erase(std::remove_if(_queues.begin(), _queues.end(), _queues.clear();
[=](const std::shared_ptr<sycl::queue> &q) -> bool
{
return q.get() == queue;
}),
_queues.end());
queue = nullptr;
} }
void set_saved_queue(sycl::queue *q) void set_saved_queue(sycl::queue q) {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
_saved_queue = q; _saved_queue = q;
} }
sycl::queue *get_saved_queue() const sycl::queue get_saved_queue() const {
{
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return _saved_queue; return _saved_queue;
} }
sycl::context get_context() const { return _ctx; }
private: private:
void clear_queues() void clear_queues() { _queues.clear(); }
{
_queues.clear();
_q_in_order = _q_out_of_order = _saved_queue = nullptr;
}
void init_queues() void init_queues() {
{ _q_in_order =
_q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); create_queue_impl(true, sycl::property::queue::in_order());
_q_out_of_order = create_queue_impl(true); _q_out_of_order = create_queue_impl(true);
_saved_queue = &default_queue(); _saved_queue = default_queue();
} }
/// Caller should acquire resource \p m_mutex before calling this function. /// Caller should acquire resource \p m_mutex before calling this
/// function.
template <class... Properties> template <class... Properties>
sycl::queue *create_queue_impl(bool enable_exception_handler, sycl::queue create_queue_impl(bool enable_exception_handler,
Properties... properties) Properties... properties) {
{
sycl::async_handler eh = {}; sycl::async_handler eh = {};
if (enable_exception_handler) if (enable_exception_handler) {
{
eh = exception_handler; eh = exception_handler;
} }
_queues.push_back(std::make_shared<sycl::queue>( auto q = sycl::queue(*this, eh,
_ctx, *this, eh,
sycl::property_list( sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
sycl::property::queue::enable_profiling(), sycl::property::queue::enable_profiling(),
#endif #endif
properties...))); properties...));
_queues.push_back(q);
return _queues.back().get(); return _queues.back();
} }
template <class... Properties> template <class... Properties>
sycl::queue *create_queue_impl(sycl::context context, sycl::device device, sycl::queue create_queue_impl(sycl::device device,
bool enable_exception_handler, bool enable_exception_handler,
Properties... properties) { Properties... properties) {
sycl::async_handler eh = {}; sycl::async_handler eh = {};
if (enable_exception_handler) { if (enable_exception_handler) {
eh = exception_handler; eh = exception_handler;
} }
_queues.push_back(std::make_shared<sycl::queue>( _queues.push_back(
context, device, eh, sycl::queue(device, eh,
sycl::property_list( sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
sycl::property::queue::enable_profiling(), sycl::property::queue::enable_profiling(),
#endif #endif
properties...))); properties...)));
return _queues.back().get(); return _queues.back();
} }
void get_version(int &major, int &minor) const void get_version(int &major, int &minor) const {
{
detail::get_version(*this, major, minor); detail::get_version(*this, major, minor);
} }
sycl::queue *_q_in_order, *_q_out_of_order; sycl::queue _q_in_order, _q_out_of_order;
sycl::queue *_saved_queue; sycl::queue _saved_queue;
sycl::context _ctx; std::vector<sycl::queue> _queues;
std::vector<std::shared_ptr<sycl::queue>> _queues;
mutable mutex_type m_mutex; mutable mutex_type m_mutex;
}; };
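Background for the queue refactor above: in SYCL 2020 a sycl::queue is a reference-counted handle, so storing and returning queues by value is cheap and all copies refer to the same underlying device queue. A minimal standalone sketch (illustrative only, assuming a SYCL 2020 toolchain; not part of this patch):

#include <sycl/sycl.hpp>
#include <iostream>

int main() {
    sycl::queue q1{sycl::default_selector_v, sycl::property::queue::in_order()};
    sycl::queue q2 = q1; // a second handle to the same underlying queue
    std::cout << std::boolalpha << (q1 == q2) << '\n'; // true: copies compare equal
    q2.submit([](sycl::handler & cgh) { cgh.single_task([] {}); }).wait();
}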
/// device manager /// device manager
class dev_mgr class dev_mgr
{ {

File diff suppressed because it is too large

File diff suppressed because it is too large

ggml.h

@ -312,6 +312,12 @@
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb) GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
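The new GGML_TENSOR_BINARY_OP_LOCALS01 macro expands GGML_TENSOR_LOCALS for both sources only, yielding ne00..ne03/nb00..nb03 for src0 and ne10..ne13/nb10..nb13 for src1, without also declaring dst locals. A hypothetical kernel body showing the intended use (the function name is made up; assumes src0/src1 parameters as in other ggml compute kernels):

static void ggml_compute_forward_foo(
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_TENSOR_BINARY_OP_LOCALS01
    // ne00../nb00.. mirror src0->ne/src0->nb; ne10../nb10.. mirror src1's
    GGML_ASSERT(ne00 == ne10); // e.g. require matching row length
    (void) dst;
}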
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif

llama.cpp

@ -2293,6 +2293,8 @@ struct llama_vocab {
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
int max_token_len = 0; // used for optimizing longest token search
std::unordered_map<token, id> token_to_id; std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token; std::vector<token_data> id_to_token;
@ -4939,6 +4941,7 @@ static void llm_load_vocab(
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
auto & token_data = vocab.id_to_token[i]; auto & token_data = vocab.id_to_token[i];
token_data.text = std::move(word); token_data.text = std::move(word);
@ -5249,6 +5252,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
if (model.arch == LLM_ARCH_DEEPSEEK2) { if (model.arch == LLM_ARCH_DEEPSEEK2) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
@ -7649,6 +7654,50 @@ struct llm_build_context {
return lctx.inp_s_seq; return lctx.inp_s_seq;
} }
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
// find result_norm tensor for input
struct ggml_tensor * inp = nullptr;
for (int i = gf->n_nodes - 1; i >= 0; --i) {
inp = gf->nodes[i];
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
break;
} else {
inp = nullptr;
}
}
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
struct ggml_tensor * cur;
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
{
struct ggml_tensor * inp_mean = build_inp_mean();
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
} break;
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
default:
{
GGML_ASSERT(false && "unknown pooling type");
} break;
}
cb(cur, "result_embd_pooled", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
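With pooling now appended inside the graph, callers read one vector per sequence after decoding. A hedged usage sketch (assumes an existing context created with embeddings enabled and a pooling type other than NONE):

#include "llama.h"
#include <cstdio>

// print the first component of the pooled embedding of sequence 0
static void print_pooled(struct llama_context * ctx) {
    const int n_embd = llama_n_embd(llama_get_model(ctx));
    const float * embd = llama_get_embeddings_seq(ctx, 0);
    if (embd) {
        printf("n_embd = %d, pooled[0] = %f\n", n_embd, embd[0]);
    }
}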
struct ggml_cgraph * build_llama() { struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -8629,8 +8678,6 @@ struct llm_build_context {
if (model.arch != LLM_ARCH_JINA_BERT_V2) { if (model.arch != LLM_ARCH_JINA_BERT_V2) {
inp_pos = build_inp_pos(); inp_pos = build_inp_pos();
} }
struct ggml_tensor * inp_mean = build_inp_mean();
struct ggml_tensor * inp_cls = build_inp_cls();
// construct input embeddings (token, type, position) // construct input embeddings (token, type, position)
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@ -8805,28 +8852,6 @@ struct llm_build_context {
cur = inpL; cur = inpL;
cb(cur, "result_embd", -1); cb(cur, "result_embd", -1);
// pooling layer
switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// nop
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_CLS:
{
cur = ggml_get_rows(ctx0, cur, inp_cls);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "Invalid pooling type");
} break;
}
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
return gf; return gf;
@ -11911,6 +11936,11 @@ static struct ggml_cgraph * llama_build_graph(
GGML_ASSERT(false); GGML_ASSERT(false);
} }
// add on pooling layer
if (lctx.cparams.embeddings) {
result = llm.append_pooling(result);
}
llm.free(); llm.free();
return result; return result;
@ -12000,7 +12030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// (!a || b) is a logical implication (a -> b) // (!a || b) is a logical implication (a -> b)
// !hparams.causal_attn -> !cparams.causal_attn // !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) && (hparams.causal_attn || !cparams.causal_attn) &&
"causal attention with embedding models is not supported" "causal attention is not supported by this model"
); );
if (lctx.inp_KQ_mask) { if (lctx.inp_KQ_mask) {
@ -12132,6 +12162,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
} }
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = i;
}
}
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
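A worked example of the scan above, extracted into a standalone check (same logic, illustrative data): five batch rows, where rows 0-2 belong to sequence 0 at positions 0-2 and rows 3-4 to sequence 1 at positions 0-1, so the last rows are 2 and 4:

#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> seq_id = {0, 0, 0, 1, 1}; // seq id per batch row
    const std::vector<int> pos    = {0, 1, 2, 0, 1}; // position per batch row
    std::vector<int> last_pos(2, -1), last_row(2, -1);
    for (int i = 0; i < (int) seq_id.size(); ++i) {
        if (pos[i] >= last_pos[seq_id[i]]) {
            last_pos[seq_id[i]] = pos[i];
            last_row[seq_id[i]] = i;
        }
    }
    printf("last_row: %d %d\n", last_row[0], last_row[1]); // last_row: 2 4
}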
if (kv_self.recurrent) { if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self.n;
@ -12193,8 +12254,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead // TODO: use a per-batch flag for logits presence instead
const bool has_logits = cparams.causal_attn; const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@ -12324,11 +12385,13 @@ static int llama_decode_internal(
std::vector<std::vector<llama_seq_id>> seq_id; std::vector<std::vector<llama_seq_id>> seq_id;
// count outputs // count outputs
if (batch_all.logits) { if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
n_outputs = n_tokens_all;
} else if (batch_all.logits) {
for (uint32_t i = 0; i < n_tokens_all; ++i) { for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch_all.logits[i] != 0; n_outputs += batch_all.logits[i] != 0;
} }
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { } else if (lctx.logits_all) {
n_outputs = n_tokens_all; n_outputs = n_tokens_all;
} else { } else {
// keep last output only // keep last output only
@ -12459,30 +12522,13 @@ static int llama_decode_internal(
// no output // no output
res = nullptr; res = nullptr;
embd = nullptr; embd = nullptr;
} else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else if (cparams.embeddings) { } else if (cparams.embeddings) {
// the embeddings could be in the second to last tensor, or any of the previous tensors res = nullptr; // do not extract logits for embedding case
int i_embd = gf->n_nodes - 2; embd = gf->nodes[gf->n_nodes - 1];
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { if (strcmp(embd->name, "result_embd_pooled") != 0) {
i_embd = gf->n_nodes - i; embd = gf->nodes[gf->n_nodes - 2];
if (i_embd < 0) { break; }
embd = gf->nodes[i_embd];
}
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
if (!cparams.causal_attn) {
res = nullptr; // do not extract logits when not needed
// skip computing logits
// TODO: is this safe?
gf->n_nodes = i_embd + 1;
} }
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
} else { } else {
embd = nullptr; // do not extract embeddings when not needed embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@ -12551,11 +12597,10 @@ static int llama_decode_internal(
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{ {
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear(); embd_seq_out.clear();
@ -13448,7 +13493,7 @@ private:
struct llm_tokenizer_wpm { struct llm_tokenizer_wpm {
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
const auto & token_map = vocab.token_to_id; const auto & token_map = vocab.token_to_id;
// normalize and split by whitespace // normalize and split by whitespace
@ -13457,7 +13502,7 @@ struct llm_tokenizer_wpm {
// bos token prepended already // bos token prepended already
// find the longest tokens that form the words // find the longest tokens that form the words
for (const std::string &word : words) { for (const std::string & word : words) {
// skip empty words // skip empty words
if (word.size() == 0) { if (word.size() == 0) {
continue; continue;
@ -13474,7 +13519,7 @@ struct llm_tokenizer_wpm {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
// loop through possible match length // loop through possible match length
bool match = false; bool match = false;
for (int j = n; j > i; j--) { for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
auto it = token_map.find(word1.substr(i, j - i)); auto it = token_map.find(word1.substr(i, j - i));
if (it != token_map.end()) { if (it != token_map.end()) {
output.push_back(it->second); output.push_back(it->second);
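The std::min bound above stops the matcher from probing substrings longer than the longest vocabulary token, cutting the inner scan from O(n) to O(max_token_len) per position. A simplified standalone sketch of the bounded greedy longest-match idea (not the actual tokenizer, which offsets the bound by one and falls back to UNK):

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<int> greedy_match(const std::string & w,
                                     const std::unordered_map<std::string, int> & vocab,
                                     int max_token_len) {
    std::vector<int> out;
    const int n = (int) w.size();
    int i = 0;
    while (i < n) {
        int j = std::min(n, i + max_token_len); // never probe past the longest token
        for (; j > i; --j) {
            auto it = vocab.find(w.substr(i, j - i));
            if (it != vocab.end()) { out.push_back(it->second); break; }
        }
        if (j == i) return {}; // no token matches here
        i = j;
    }
    return out;
}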
@ -13497,7 +13542,8 @@ struct llm_tokenizer_wpm {
} }
} }
std::vector<std::string> preprocess(const std::string & text) { // TODO: reduce string copies by using cpts_offs array
std::vector<std::string> preprocess(const std::string & text) const {
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
std::vector<std::string> words(1, ""); std::vector<std::string> words(1, "");
@ -13792,6 +13838,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(vocab.special_cls_id); output.push_back(vocab.special_cls_id);
} }
llm_tokenizer_wpm tokenizer(vocab);
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -13799,7 +13847,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
llm_tokenizer_wpm tokenizer(vocab);
tokenizer.tokenize(raw_text, output); tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
@ -18112,6 +18159,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data; ctx->abort_callback_data = abort_callback_data;
} }
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
ctx->cparams.embeddings = embeddings;
}
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn; ctx->cparams.causal_attn = causal_attn;
} }


@ -174,6 +174,7 @@ extern "C" {
LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
}; };
enum llama_split_mode { enum llama_split_mode {
@ -293,7 +294,6 @@ extern "C" {
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
// (ignored if no pooling layer)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054 // ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_base; // RoPE base frequency, 0 = from model
@ -786,6 +786,10 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token). // Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx); LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set whether to use causal attention or not // Set whether to use causal attention or not
// If set to true, the model will only attend to the past tokens // If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
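A hedged usage sketch for the new setter (assumes an existing ctx): one context can now switch between embedding extraction and logits generation without being re-created:

// enable embeddings: decode fills embeddings and skips logits
llama_set_embeddings(ctx, true);
// ... llama_decode(ctx, batch); llama_get_embeddings_seq(ctx, 0); ...
// disable again to get logits back for generation
llama_set_embeddings(ctx, false);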


@ -1,2 +1,2 @@
-r ./requirements-convert-legacy-llama.txt -r ./requirements-convert-legacy-llama.txt
torch~=2.1.1 torch~=2.2.1


@ -1,2 +1,2 @@
-r ./requirements-convert-legacy-llama.txt -r ./requirements-convert-legacy-llama.txt
torch~=2.1.1 torch~=2.2.1


@ -1,4 +1,4 @@
numpy~=1.24.4 numpy~=1.26.4
sentencepiece~=0.2.0 sentencepiece~=0.2.0
transformers>=4.40.1,<5.0.0 transformers>=4.40.1,<5.0.0
gguf>=0.1.0 gguf>=0.1.0


@ -785,6 +785,10 @@ struct test_cpy : public test_case {
return VARS_TO_STR3(type_src, type_dst, ne); return VARS_TO_STR3(type_src, type_dst, ne);
} }
double max_nmse_err() override {
return 1e-6;
}
size_t op_size(ggml_tensor * t) override { size_t op_size(ggml_tensor * t) override {
return ggml_nbytes(t) + ggml_nbytes(t->src[0]); return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
} }
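For reference, the threshold added above bounds a normalized mean-squared error: the squared difference between backend and reference output, normalized by the reference signal's energy. A sketch of that metric (illustrative; mirrors how such a check is typically computed):

#include <cstddef>

static double nmse(const float * a, const float * b, size_t n) {
    double err = 0.0, ref = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const double d = (double) a[i] - (double) b[i];
        err += d * d;
        ref += (double) a[i] * (double) a[i];
    }
    return err / ref; // 1e-6 means errors must stay tiny relative to the signal
}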


@ -7,11 +7,16 @@
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "grammar-parser.h" #include "grammar-parser.h"
#include "json-schema-to-grammar.h"
#include "unicode.h" #include "unicode.h"
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <vector> #include <vector>
using json = nlohmann::ordered_json;
//#define INCLUDE_FAILING_TESTS 1
static llama_grammar* build_grammar(const std::string & grammar_str) { static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@ -65,8 +70,8 @@ static bool match_string(const std::string & input, llama_grammar* grammar) {
return false; return false;
} }
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) { static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr); fflush(stderr);
auto grammar = build_grammar(grammar_str); auto grammar = build_grammar(grammar_str);
@ -85,6 +90,23 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
if (!matched) { if (!matched) {
fprintf(stderr, "❌ (failed to match)\n"); fprintf(stderr, "❌ (failed to match)\n");
// DEBUG: Write strings to files so that we can analyze failures more easily with the gbnf-validator program and see exactly where things failed.
// DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
FILE* grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
if (grammar_file) {
fprintf(grammar_file, "%s", grammar_str.c_str());
fclose(grammar_file);
}
// DEBUG: Write the test string to test-grammar-integration.string.txt
FILE* string_file = fopen("test-grammar-integration.string.txt", "w");
if (string_file) {
fprintf(string_file, "%s", test_string.c_str());
fclose(string_file);
}
fprintf(stderr, "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following command: ./llama-gbnf-validator test-grammar-integration.grammar.gbnf test-grammar-integration.string.txt\n\n");
} else { } else {
fprintf(stdout, "✅︎\n"); fprintf(stdout, "✅︎\n");
} }
@ -118,6 +140,12 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
// Clean up allocated memory // Clean up allocated memory
llama_grammar_free(grammar); llama_grammar_free(grammar);
} }
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
}
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
}
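test_schema() funnels each JSON schema through json_schema_to_grammar() and then reuses the ordinary grammar matcher. A hedged standalone sketch of that conversion step (assumes common/json-schema-to-grammar.h is on the include path):

#include "json-schema-to-grammar.h"
#include <nlohmann/json.hpp>
#include <cstdio>

int main() {
    const auto schema = nlohmann::ordered_json::parse(R"({"type": "integer"})");
    const std::string gbnf = json_schema_to_grammar(schema);
    printf("%s\n", gbnf.c_str()); // prints the generated GBNF rules
}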
static void test_simple_grammar() { static void test_simple_grammar() {
// Test case for a simple grammar // Test case for a simple grammar
@ -400,10 +428,11 @@ static void test_quantifiers() {
static void test_failure_missing_root() { static void test_failure_missing_root() {
fprintf(stderr, "⚫ Testing missing root node:\n"); fprintf(stderr, "⚫ Testing missing root node:\n");
// Test case for a grammar that is missing a root rule // Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr const std::string grammar_str = R"""(
expr ::= term ("+" term)* rot ::= expr
term ::= number expr ::= term ("+" term)*
number ::= [0-9]+)"""; term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@ -420,10 +449,10 @@ static void test_failure_missing_reference() {
// Test case for a grammar that is missing a referenced rule // Test case for a grammar that is missing a referenced rule
const std::string grammar_str = const std::string grammar_str =
R"""(root ::= expr R"""(root ::= expr
expr ::= term ("+" term)* expr ::= term ("+" term)*
term ::= numero term ::= numero
number ::= [0-9]+)"""; number ::= [0-9]+)""";
fprintf(stderr, " Expected error: "); fprintf(stderr, " Expected error: ");
@ -445,29 +474,558 @@ static void test_failure_left_recursion() {
// Test more complicated left recursion detection // Test more complicated left recursion detection
const std::string medium_str = R"""( const std::string medium_str = R"""(
root ::= asdf root ::= asdf
asdf ::= "a" | asdf "a" asdf ::= "a" | asdf "a"
)"""; )""";
assert(test_build_grammar_fails(medium_str)); assert(test_build_grammar_fails(medium_str));
// Test even more complicated left recursion detection // Test even more complicated left recursion detection
const std::string hard_str = R"""( const std::string hard_str = R"""(
root ::= asdf root ::= asdf
asdf ::= "a" | foo "b" asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")"""; foo ::= "c" | asdf "d" | "e")""";
assert(test_build_grammar_fails(hard_str)); assert(test_build_grammar_fails(hard_str));
// Test yet even more complicated left recursion detection // Test yet even more complicated left recursion detection
const std::string hardest_str = R"""( const std::string hardest_str = R"""(
root ::= asdf root ::= asdf
asdf ::= "a" | foo "b" asdf ::= "a" | foo "b"
foo ::= "c" | empty asdf "d" | "e" foo ::= "c" | empty asdf "d" | "e"
empty ::= "blah" | )"""; empty ::= "blah" | )""";
assert(test_build_grammar_fails(hardest_str)); assert(test_build_grammar_fails(hardest_str));
fprintf(stderr, " ✅︎ Passed\n"); fprintf(stderr, " ✅︎ Passed\n");
} }
static void test_json_schema() {
// Note that this is similar to the regular grammar tests,
// but we convert each json schema to a grammar before parsing.
// Otherwise, the test structure is the same.
test_schema(
"empty schema (object)",
// Schema
R"""(
{}
)""",
// Passing strings
{
"{}",
R"""({"foo": "bar"})""",
},
// Failing strings
{
"",
"[]",
"null",
"\"\"",
"true",
}
);
test_schema(
"exotic formats (list)",
// Schema
R"""(
{
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
}
)""",
// Passing strings
{
// "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
// "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
//R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
//R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
},
// Failing strings
{
R"""(["foo", "bar"])""",
R"""(["12345678-1234-1234-1234-1234567890ab"])""",
}
);
test_schema(
"string",
// Schema
R"""(
{
"type": "string"
}
)""",
// Passing strings
{
"\"foo\"",
"\"bar\"",
"\"\"",
},
// Failing strings
{
"{}",
"\"foo\": \"bar\"",
}
);
test_schema(
"string w/ min length 1",
// Schema
R"""(
{
"type": "string",
"minLength": 1
}
)""",
// Passing strings
{
"\"foo\"",
"\"bar\"",
},
// Failing strings
{
"\"\"",
"{}",
"\"foo\": \"bar\"",
}
);
test_schema(
"string w/ min length 3",
// Schema
R"""(
{
"type": "string",
"minLength": 3
}
)""",
// Passing strings
{
"\"foo\"",
"\"bar\"",
"\"foobar\"",
},
// Failing strings
{
"\"\"",
"\"f\"",
"\"fo\"",
}
);
test_schema(
"string w/ max length",
// Schema
R"""(
{
"type": "string",
"maxLength": 3
}
)""",
// Passing strings
{
"\"foo\"",
"\"bar\"",
"\"\"",
"\"f\"",
"\"fo\"",
},
// Failing strings
{
"\"foobar\"",
}
);
test_schema(
"string w/ min & max length",
// Schema
R"""(
{
"type": "string",
"minLength": 1,
"maxLength": 4
}
)""",
// Passing strings
{
"\"foo\"",
"\"bar\"",
"\"f\"",
"\"barf\"",
},
// Failing strings
{
"\"\"",
"\"barfo\"",
"\"foobar\"",
}
);
test_schema(
"boolean",
// Schema
R"""(
{
"type": "boolean"
}
)""",
// Passing strings
{
"true",
"false",
},
// Failing strings
{
"\"\"",
"\"true\"",
"True",
"FALSE",
}
);
test_schema(
"integer",
// Schema
R"""(
{
"type": "integer"
}
)""",
// Passing strings
{
"0",
"12345",
"1234567890123456"
},
// Failing strings
{
"",
"01",
"007",
"12345678901234567"
}
);
test_schema(
"string const",
// Schema
R"""(
{
"const": "foo"
}
)""",
// Passing strings
{
"\"foo\"",
},
// Failing strings
{
"foo",
"\"bar\"",
}
);
test_schema(
"non-string const",
// Schema
R"""(
{
"const": true
}
)""",
// Passing strings
{
"true",
},
// Failing strings
{
"",
"foo",
"\"true\"",
}
);
test_schema(
"non-string const",
// Schema
R"""(
{
"enum": ["red", "amber", "green", null, 42, ["foo"]]
}
)""",
// Passing strings
{
"\"red\"",
"null",
"42",
"[\"foo\"]",
},
// Failing strings
{
"",
"420",
"true",
"foo",
}
);
test_schema(
"min+max items",
// Schema
R"""(
{
"items": {
"type": ["number", "integer"]
},
"minItems": 3,
"maxItems": 5
}
)""",
// Passing strings
{
"[1, 2, 3]",
"[1, 2, 3, 4]",
"[1, 2, 3, 4, 5]",
},
// Failing strings
{
"[1, 2]",
"[1, 2, 3, 4, 5, 6]",
"1"
}
);
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_schema(
"object properties",
// Schema
R"""(
{
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
}
}
)""",
// Passing strings
{
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
// "By default, leaving out properties is valid"
R"""({ "street_name": "Pennsylvania" })""",
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
// "By extension, even an empty object is valid"
R"""({})""",
// "By default, providing additional properties is valid"
#ifdef INCLUDE_FAILING_TESTS
// TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
// TODO: Spaces should be permitted around enum values, but currently they fail to pass.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
#endif
},
// Failing strings
{
// Change datatype from number to string
R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
// Reorder properties
R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
// Reorder properties
R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
}
);
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_schema(
"object properties, additionalProperties: true",
// Schema
R"""(
{
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": true
}
)""",
// Passing strings
{
// "By extension, even an empty object is valid"
R"""({})""",
#ifdef INCLUDE_FAILING_TESTS
// TODO: Following line should pass and doesn't
R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
// "By default, leaving out properties is valid"
// TODO: Following line should pass and doesn't
R"""({ "street_name": "Pennsylvania" })""",
// TODO: Following line should pass and doesn't
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
// "By default, providing additional properties is valid"
// TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
// TODO: Spaces should be permitted around enum values, but currently they fail to pass.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
#endif
},
// Failing strings
{
// Change datatype from number to string
R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
// Reorder properties
R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
}
);
// Additional properties: false
test_schema(
"required + optional props each in original order",
// Schema
R"""(
{
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
}
)""",
// Passing strings
{
R"""({ "street_name": "Pennsylvania" })""",
R"""({ "number": 1600, "street_type":"Avenue"})""",
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
#ifdef INCLUDE_FAILING_TESTS
// TODO: Spaces should be permitted around enum values, but currently they fail to pass.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
#endif
},
// Failing strings
{
// Reorder properties
R"""({ "street_type": "Avenue", "number": 1600 })""",
// Add "direction"
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
}
);
test_schema(
"required + optional props each in original order",
// Schema
R"""(
{
"properties": {
"b": {"type": "string"},
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
}
)""",
// Passing strings
{
R"""({"b": "foo", "a": "bar"})""",
R"""({"b":"foo","a":"bar","d":"qux"})""",
R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
},
// Failing strings
{
R"""({"a": "foo", "b": "bar"})""",
R"""({"b": "bar"})""",
R"""({"a": "foo", "c": "baz"})""",
R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
}
);
// NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
test_schema(
"required props",
// Schema
R"""(
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json",
"title": "Product",
"description": "A product from Acme's catalog",
"type": "object",
"properties": {
"productId": {
"description": "The unique identifier for a product",
"type": "integer"
},
"productName": {
"description": "Name of the product",
"type": "string"
},
"price": {
"description": "The price of the product",
"type": "number",
"exclusiveMinimum": 0
},
"tags": {
"description": "Tags for the product",
"type": "array",
"items": {
"type": "string"
},
"minItems": 1,
"uniqueItems": true
},
"dimensions": {
"type": "object",
"properties": {
"length": {
"type": "number"
},
"width": {
"type": "number"
},
"height": {
"type": "number"
}
},
"required": [ "length", "width", "height" ]
}
},
"required": [ "productId", "productName", "price" ]
}
)""",
// Passing strings
{
R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
},
// Failing strings
{
R"""({})""", // Missing all required properties
R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties
// TODO: The following line should fail, but currently it passes. `exclusiveMinimum` is not supported, as it would likely be too difficult to implement.
// Perhaps special checks for minimum and maximum values of 0 could be added (since that's relatively easy to do with grammars), but anything else would likely be too complex.
// R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price)
R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId)
R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1
R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order
// TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
// R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
}
);
}
int main() { int main() {
fprintf(stdout, "Running grammar integration tests...\n"); fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar(); test_simple_grammar();
@ -477,6 +1035,7 @@ int main() {
test_failure_missing_root(); test_failure_missing_root();
test_failure_missing_reference(); test_failure_missing_reference();
test_failure_left_recursion(); test_failure_left_recursion();
test_json_schema();
fprintf(stdout, "All tests passed.\n"); fprintf(stdout, "All tests passed.\n");
return 0; return 0;
} }


@ -596,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) { std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result; std::vector<uint32_t> result;
result.reserve(utf8.size());
size_t offset = 0; size_t offset = 0;
while (offset < utf8.size()) { while (offset < utf8.size()) {
result.push_back(unicode_cpt_from_utf8(utf8, offset)); result.push_back(unicode_cpt_from_utf8(utf8, offset));


@ -13,7 +13,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE]; shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
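The shader change above (repeated for the other mul_mat_vec variants below) derives the row from a flattened 2D dispatch, so grids larger than the per-dimension workgroup-count limit can spill into gl_WorkGroupID.z. A standalone sketch of the host-side split and the flattening arithmetic (illustrative numbers; the real limit comes from the Vulkan device properties):

#include <cstdio>

int main() {
    const unsigned max_x = 65535;  // assumed maxComputeWorkGroupCount[0]
    const unsigned rows  = 200000; // want one workgroup per row
    const unsigned num_x = max_x;
    const unsigned num_z = (rows + num_x - 1) / num_x; // host splits the grid
    // in the shader: row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z
    const unsigned wg_x = 123, wg_z = 2;
    printf("row = %u (num_z = %u)\n", wg_x + num_x * wg_z, num_z);
}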


@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);


@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);


@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);


@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);


@ -7,7 +7,7 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);