Merge remote-tracking branch 'origin/master' into sl/backend-sched

ggml-ci
This commit is contained in:
slaren 2024-01-09 20:42:51 +01:00
commit 2e7814a8c7
28 changed files with 2465 additions and 395 deletions

1
.gitignore vendored
View file

@ -51,6 +51,7 @@ models-mnt
/lookup
/main
/metal
/passkey
/perplexity
/q8dot
/quantize

View file

@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = \
@ -665,6 +665,9 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ifdef LLAMA_METAL
metal: examples/metal/metal.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

View file

@ -21,7 +21,7 @@ let package = Package(
name: "llama",
dependencies: ["ggml"],
path: ".",
exclude: [],
exclude: ["ggml-metal.metal"],
sources: [
"llama.cpp",
],

View file

@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics
- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
- Collecting Apple Silicon performance stats:
- M-series: https://github.com/ggerganov/llama.cpp/discussions/4167
- A-series: https://github.com/ggerganov/llama.cpp/discussions/4508
@ -118,6 +119,7 @@ as the main playground for developing new features for the [ggml](https://github
- Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
- Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
- Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
@ -135,6 +137,7 @@ as the main playground for developing new features for the [ggml](https://github
- [semperai/amica](https://github.com/semperai/amica)
- [psugihara/FreeChat](https://github.com/psugihara/FreeChat)
- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal)
- [iohub/collama](https://github.com/iohub/coLLaMA)
---

View file

@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--grp-attn-n" || arg == "-gan") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.grp_attn_n = std::stoi(argv[i]);
} else if (arg == "--grp-attn-w" || arg == "-gaw") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.grp_attn_w = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@ -916,6 +930,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
#endif
printf(" -gan N, --grp-attn-n N\n");
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
printf(" -gaw N, --grp-attn-w N\n");
printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");

View file

@ -63,6 +63,8 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor

File diff suppressed because it is too large Load diff

View file

@ -31,6 +31,7 @@ else()
add_subdirectory(quantize-stats)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(passkey)
add_subdirectory(speculative)
add_subdirectory(lookahead)
add_subdirectory(lookup)

View file

@ -69,6 +69,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(model, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
// initialize the context

View file

@ -1,7 +1,12 @@
# llama.swiftui
# llama.cpp/examples/llama.swiftui
Local inference of llama.cpp on an iPhone.
So far I only tested with starcoder 1B model, but it can most likely handle 7B models as well.
Local inference of llama.cpp on an iPhone. This is a sample app that can be used as a starting
point for more advanced projects.
For usage instructions and performance stats, check the following discussion: https://github.com/ggerganov/llama.cpp/discussions/4508
![image](https://github.com/ggerganov/llama.cpp/assets/1991296/2b40284f-8421-47a2-b634-74eece09a299)
Video demonstration:
https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545

View file

@ -243,6 +243,9 @@ int main(int argc, char ** argv) {
}
auto image_embed = load_image(ctx_llava, &params);
if (!image_embed) {
return 1;
}
// process the prompt
process_prompt(ctx_llava, image_embed, &params, params.prompt);

View file

@ -439,6 +439,21 @@ int main(int argc, char ** argv) {
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// group-attention state
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
int ga_i = 0;
const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;
if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
}
LOG_TEE("\n\n");
if (params.interactive) {
@ -500,37 +515,61 @@ int main(int argc, char ** argv) {
fflush(stdout);
}
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
if (ga_n == 1) {
// infinite text generation via context shifting
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
if (ctx_guidance) {
n_past_guidance -= n_discard;
}
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("clear session path\n");
path_session.clear();
}
} else {
// context extension via Self-Extend
while (n_past >= ga_i + ga_w) {
const int ib = (ga_n*ga_i)/ga_w;
const int bd = (ga_w/ga_n)*(ga_n - 1);
const int dd = (ga_w/ga_n) - ib*bd - ga_w;
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
LOG("\n");
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= bd;
n_past -= n_discard;
ga_i += ga_w/ga_n;
if (ctx_guidance) {
n_past_guidance -= n_discard;
LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
}
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("clear session path\n");
path_session.clear();
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)

View file

@ -0,0 +1,5 @@
set(TARGET passkey)
add_executable(${TARGET} passkey.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -0,0 +1,12 @@
# llama.cpp/example/passkey
See the following PRs for more info:
- https://github.com/ggerganov/llama.cpp/pull/3856
- https://github.com/ggerganov/llama.cpp/pull/4810
### Usage
```bash
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250
```

View file

@ -0,0 +1,296 @@
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
int main(int argc, char ** argv) {
gpt_params params;
if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
return 1 ;
}
int seed = -1;
int n_junk = 250; // number of times to repeat the junk text
int n_keep = 32; // number of tokens in the prompt prefix
int n_grp = 1; // if more than 1 - perform LongLM SelfExtend
int i_pos = -1; // position of the passkey in the junk text
if (argc >= 2) {
params.model = argv[1];
}
if (argc >= 3) {
n_junk = std::stoi(argv[2]);
}
if (argc >= 4) {
n_grp = std::stoi(argv[3]);
}
if (argc >= 5) {
i_pos = std::stoi(argv[4]);
}
if (argc >= 6) {
seed = std::stoi(argv[5]);
}
if (seed == -1) {
seed = time(NULL);
}
srand(seed);
if (i_pos == -1) {
i_pos = rand() % n_junk;
}
const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
const std::string prompt_suffix = " What is the pass key? The pass key is";
// generate junk text
params.prompt = prompt_prefix;
const int passkey = rand() % 50000 + 1;
for (int i = 0; i < n_junk; i++) {
if (i % n_junk == i_pos) {
params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
}
params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
}
params.prompt += prompt_suffix;
// init LLM
llama_backend_init(params.numa);
// initialize the model
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = seed;
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
ctx_params.n_batch = 512;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
// tokenize the prefix and use it as a sink
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
const int n_tokens_all = tokens_list.size();
// we leave a margin of 16 tokens for the generated text - it should contain just the passkey
const int n_predict = 16;
// total length of the sequences including the prompt
const int n_len = n_tokens_all + n_predict;
const int n_ctx = llama_n_ctx(ctx) - n_keep;
const int n_kv_req = llama_n_ctx(ctx);
const int n_batch = ctx_params.n_batch;
const int n_batch_grp = ctx_params.n_batch/n_grp;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
// print the prompt token-by-token
LOG_TEE("\n");
LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
LOG_TEE("prompt tokens: %d\n", n_tokens_all);
//LOG_TEE("prompt: %s\n", params.prompt.c_str());
llama_batch batch = llama_batch_init(512, 0, 1);
int n_past = 0;
// fill the KV cache
for (int i = 0; i < n_ctx; i += n_batch) {
if (i > 0 && n_grp > 1) {
// if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
n_past -= bd;
}
llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
}
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
if (i + n_batch >= n_tokens_all) {
break;
}
}
for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
const int n_discard = n_batch;
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
n_past -= n_discard;
llama_batch_clear(batch);
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
}
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
}
{
const int n_discard = n_past - n_ctx + n_predict;
if (n_discard > 0) {
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
n_past -= n_discard;
}
}
LOG_TEE("\n");
LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
LOG_TEE("\n");
// main loop
int n_cur = n_tokens_all;
int n_decode = 0;
LOG_TEE("%s", prompt_suffix.c_str());
fflush(stdout);
const auto t_main_start = ggml_time_us();
while (n_cur <= n_len) {
// sample the next token
{
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream?
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
LOG_TEE("\n");
break;
}
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
fflush(stdout);
n_decode += 1;
// prepare the next batch
llama_batch_clear(batch);
// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
LOG_TEE("\n");
const auto t_main_end = ggml_time_us();
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
llama_print_timings(ctx);
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
return 0;
}

View file

@ -23,6 +23,7 @@ Command line options:
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
- `--path`: path from which to serve static files (default examples/server/public)
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
- `--embedding`: Enable embedding extraction, Default: disabled.
- `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1)
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
@ -174,35 +175,44 @@ node index.js
`system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
*Result JSON:*
### Result JSON:
Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
* Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
`content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
`stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:
`generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`
```
{
"content": "<the token selected by the model>",
"probs": [
{
"prob": float,
"tok_str": "<most likely token>"
},
{
"prob": float,
"tok_str": "<second most likely tonen>"
},
...
]
},
```
Notice that each `probs` is an array of length `n_probs`.
`model`: The path to the model loaded with `-m`
`prompt`: The provided `prompt`
`stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
`stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
`stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
`stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
`timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
`tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
`tokens_evaluated`: Number of tokens evaluated in total from the prompt
`truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`
- `model`: The path to the model loaded with `-m`
- `prompt`: The provided `prompt`
- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
- `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
- `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
- `tokens_evaluated`: Number of tokens evaluated in total from the prompt
- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
- **POST** `/tokenize`: Tokenize a given text.

View file

@ -118,6 +118,7 @@
#endif // defined(GGML_USE_HIPBLAS)
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_OFFSET_AMD 1000000
@ -479,6 +480,14 @@ typedef struct {
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
half d;
uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@ -550,11 +559,12 @@ static std::array<float, GGML_CUDA_MAX_DEVICES> g_default_tensor_split = {};
struct cuda_device_capabilities {
int cc; // compute capability
size_t smpb; // max. shared memory per block
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
};
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@ -583,6 +593,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
return a;
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) a;
bad_arch();
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@ -591,6 +614,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) x;
bad_arch();
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
@ -1290,6 +1326,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
#endif
}
static const __device__ uint64_t kgrid_iq2xxs[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
static const __device__ uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return true;
default:
return false;
}
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * aux8 = (const uint8_t *)q2;
const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
#endif
}
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@ -3823,6 +3981,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
#if QR2_XXS == 8
const int ib32 = iqs;
const uint16_t * q2 = bq2->qs + 4*ib32;
const uint8_t * aux8 = (const uint8_t *)q2;
const int8_t * q8 = bq8_1[ib32].qs;
uint32_t aux32 = q2[2] | (q2[3] << 16);
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
for (int j = 0; j < 8; ++j) {
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
aux32 >>= 7;
}
const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
return d * sumi;
#else
// iqs is 0...15
const int ib32 = iqs/2;
const int il = iqs%2;
const uint16_t * q2 = bq2->qs + 4*ib32;
const uint8_t * aux8 = (const uint8_t *)q2;
const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
const int8_t * q8 = bq8_1[ib32].qs + 16*il;
int sumi1 = 0, sumi2 = 0;
for (int j = 0; j < 8; ++j) {
sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
}
return d * (sumi1 + sumi2);
#endif
#else
assert(false);
return 0.f;
#endif
}
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
@ -5204,75 +5411,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.x;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
extern __shared__ half data_soft_max_f16[];
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
// (shared memory) buffer to cache values between iterations:
half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
// if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
// in that case col_smem == col_data must be enforced to avoid race conditions
float max_val = -INFINITY;
half2 max_val = make_half2(-INFINITY, -INFINITY);
for (int col = tid; col < ncols; col += block_size) {
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
const int col_smem = vals_smem ? col0 + tid : col_data;
const int ix = rowx*ncols_data + col_data;
const int iy = rowy*ncols_data + col_data;
half2 val;
if (need_check && col_data + 0 >= ncols_data) {
val.x = -INFINITY;
} else {
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
}
if (need_check && col_data + WARP_SIZE >= ncols_data) {
val.y = -INFINITY;
} else {
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
}
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
vals[col_smem] = val;
}
max_val = __hmax2(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = max_val;
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
}
__syncthreads();
max_val = buf[lane_id];
max_val = __half2half2(buf_iw[lane_id]);
max_val = warp_reduce_max(max_val);
} else {
max_val = __half2half2(__hmax(max_val.x, max_val.y));
}
float tmp = 0.f;
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
break;
}
const half2 val = h2exp(vals[col_smem] - max_val);
for (int col = tid; col < ncols; col += block_size) {
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[ix] = val;
vals[col_smem] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = 0.f;
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = tmp;
buf_iw[warp_id] = tmp.x + tmp.y;
}
__syncthreads();
tmp = buf[lane_id];
tmp = __half2half2(buf_iw[lane_id]);
tmp = warp_reduce_sum(tmp);
} else {
tmp = __half2half2(tmp.x + tmp.y);
}
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
const int col_smem = vals_smem ? col0 + tid : col_data;
const int idst = rowx*ncols_data + col_data;
const half2 result = vals[col_smem] * inv_sum;
if (need_check && col_data + 0 >= ncols_data) {
return;
}
dst[idst] = result.x;
if (need_check && col_data + WARP_SIZE >= ncols_data) {
return;
}
dst[idst + WARP_SIZE] = result.y;
}
#else
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
float max_val = -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float inv_tmp = 1.f / tmp;
const float inv_sum = 1.0f / tmp;
for (int col = tid; col < ncols; col += block_size) {
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
return;
}
const int idst = rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
}
}
@ -5662,6 +6027,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
#endif
}
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@ -5690,6 +6061,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
@ -5719,6 +6092,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
@ -5913,6 +6288,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@ -6552,12 +6936,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem <= g_device_caps[g_main_device].smpb) {
switch (ncols_x) {
case 32:
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 64:
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 128:
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 256:
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 512:
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 1024:
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 2048:
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 4096:
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
default:
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(half);
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
}
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem < g_device_caps[g_main_device].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
}
static void im2col_f32_f16_cuda(const float* x, half* dst,
@ -6873,6 +7335,7 @@ void ggml_init_cublas() {
#else
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
g_device_caps[id].smpb = prop.sharedMemPerBlock;
}
for (int id = 0; id < g_device_count; ++id) {
g_default_tensor_split[id] /= total_vram;
@ -7382,6 +7845,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
@ -7402,6 +7866,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
@ -7467,6 +7932,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
@ -7874,7 +8342,21 @@ static void ggml_cuda_op_soft_max(
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
const bool use_f16_soft_max = false;
#else
#ifdef GGML_CUDA_F16
const bool use_f16_soft_max = true;
#else
const bool use_f16_soft_max = false;
#endif // GGML_CUDA_F16
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
} else {
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
}
(void) dst;
}
@ -8703,6 +9185,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);

View file

@ -88,6 +88,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(get_rows_i32);
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
@ -106,6 +107,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@ -121,6 +123,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@ -133,6 +136,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@ -145,6 +149,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
@ -379,6 +384,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(get_rows_i32);
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
@ -397,6 +403,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@ -412,6 +419,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@ -425,6 +433,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@ -437,6 +446,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
}
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
@ -502,6 +512,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(get_rows_i32);
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
@ -520,6 +531,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@ -535,6 +547,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@ -548,6 +561,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@ -560,6 +574,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
}
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
@ -1541,6 +1556,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1653,6 +1669,12 @@ bool ggml_metal_graph_compute(
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@ -1686,9 +1708,14 @@ bool ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
//src0t == GGML_TYPE_IQ2_XXS ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS) {
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@ -1778,6 +1805,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1893,6 +1921,12 @@ bool ggml_metal_graph_compute(
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@ -1942,9 +1976,14 @@ bool ggml_metal_graph_compute(
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
//src2t == GGML_TYPE_IQ2_XXS ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_IQ2_XXS) {
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@ -1982,6 +2021,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
default: GGML_ASSERT(false && "not implemented");
}

View file

@ -2446,6 +2446,12 @@ typedef struct {
} block_q6_K;
// 210 bytes / block
typedef struct {
half d;
uint16_t qs[QK_K/8];
} block_iq2_xxs;
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
//====================================== dot products =========================
void kernel_mul_mv_q2_K_f32_impl(
@ -3468,6 +3474,221 @@ kernel void kernel_mul_mv_q6_K_f32(
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
constexpr constant static uint64_t kgrid_iq2xxs[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
constexpr constant static uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
void kernel_mul_mv_iq2_xxs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne10,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
#if QK_K == 256
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
const int ibl = ib32 / (QK_K / 32);
const int ib = ib32 % (QK_K / 32);
device const block_iq2_xxs * xr = x + ibl;
device const uint16_t * q2 = xr->qs + 4 * ib;
device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
device const uint8_t * aux8 = (device const uint8_t *)q2;
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float sum = 0;
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
sumf[row] += d * sum;
dh += nb*sizeof(block_iq2_xxs)/2;
q2 += nb*sizeof(block_iq2_xxs)/2;
}
y4 += 32 * 32;
}
#else
// TODO
#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
}
}
}
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
@ -3620,8 +3841,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
const half ml = 4.h * dl;
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@ -3688,7 +3909,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float d = il < 2 ? xb->d : xb->d / 16.f;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
@ -3721,17 +3942,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
#if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1);
half sc = scales[(il%2) + 2 * ((il/2))];
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
#else
ql = ql + 16 * (il&1);
half sc = scales[il];
float sc = scales[il];
#endif
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
const half coef = il>1 ? 1.f/16.h : 1.h;
const half ml = d_all * sc * 32.h;
const half dl = d_all * sc * coef;
const float coef = il>1 ? 1.f/16.f : 1.f;
const float ml = d_all * sc * 32.f;
const float dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) {
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@ -3739,6 +3960,31 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
}
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
@ -4278,6 +4524,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// matrix-matrix multiplication
@ -4314,6 +4561,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// indirect matrix-matrix multiplication
@ -4362,6 +4610,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
//
// matrix-vector multiplication
@ -5134,3 +5383,68 @@ kernel void kernel_mul_mv_id_q6_K_f32(
tiisg,
sgitg);
}
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
kernel void kernel_mul_mv_id_iq2_xxs_f32(
device const char * ids,
device const char * src1,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
device const char * src00,
device const char * src01,
device const char * src02,
device const char * src03,
device const char * src04,
device const char * src05,
device const char * src06,
device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
kernel_mul_mv_iq2_xxs_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
ne01,
ne02,
ne10,
ne12,
ne0,
ne1,
r2,
r3,
shared_values,
tgpig,
tiisg,
sgitg);
}

View file

@ -2340,6 +2340,138 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
return (n/QK_K*sizeof(block_q6_K));
}
// ====================== "True" 2-bit (de)-quantization
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
(void)x;
(void)y;
(void)k;
assert(k % QK_K == 0);
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}
static const uint64_t iq2xxs_grid[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
static const uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
}
}
}
void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq2_xxs * restrict y = vy;
quantize_row_iq2_xxs_reference(x, y, k);
}
size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
quantize_row_iq2_xxs_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_iq2_xxs));
}
//===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@ -2362,7 +2494,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
x += QK_K;
continue;
}
const float iscale = -128.f/max;
//const float iscale = -128.f/max;
// We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
const float iscale = -127.f/max;
for (int j = 0; j < QK_K; ++j) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
@ -7065,3 +7199,161 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
}
#endif
static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
};
void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_iq2_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#if defined(__ARM_NEON)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
ggml_int8x16x4_t q2u;
ggml_int8x16x4_t q2s;
ggml_int8x16x4_t q8b;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
float sumf1 = 0, sumf2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
}
sumf += d*(sumf1 + sumf2);
}
*s = 0.25f * sumf;
#elif defined(__AVX2__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const uint16_t ls1 = aux32[1] >> 28;
const uint16_t ls2 = aux32[3] >> 28;
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
sumi1 = _mm256_add_epi32(sumi1, p1);
sumi2 = _mm256_add_epi32(sumi2, p2);
}
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
#else
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
float sumf = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
int32_t bsum = 0;
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
memcpy(aux32, q2, 2*sizeof(uint32_t));
q2 += 4;
const uint32_t ls = 2*(aux32[1] >> 28) + 1;
int32_t sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
bsum += sumi * ls;
}
sumf += d * bsum;
}
*s = 0.125f * sumf;
#endif
}

View file

@ -165,6 +165,14 @@ typedef struct {
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
// (Almost) "true" 2-bit quantization.
// Due to the need to use blocks as per ggml dsign, it ends up using
// 2.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
ggml_fp16_t d;
uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
// Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
@ -180,6 +188,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@ -194,6 +203,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@ -209,6 +219,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@ -222,3 +233,4 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);

26
ggml.c
View file

@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
.blck_size = QK_K,
.type_size = sizeof(block_iq2_xxs),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
.from_float = quantize_row_iq2_xxs,
.from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
.blck_size = QK_K,
@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
}
@ -7457,6 +7469,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
@ -7721,6 +7734,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
@ -7835,6 +7849,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
@ -10476,6 +10491,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
@ -10650,6 +10666,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
@ -10844,6 +10861,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@ -11480,6 +11498,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@ -11554,6 +11573,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@ -18670,6 +18690,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
} break;
case GGML_TYPE_IQ2_XXS:
{
GGML_ASSERT(start % QK_K == 0);
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
} break;
case GGML_TYPE_F16:
{
int elemsize = sizeof(ggml_fp16_t);

3
ggml.h
View file

@ -339,6 +339,7 @@ extern "C" {
GGML_TYPE_Q5_K = 13,
GGML_TYPE_Q6_K = 14,
GGML_TYPE_Q8_K = 15,
GGML_TYPE_IQ2_XXS = 16,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -373,6 +374,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
};
// available tensor operations:
@ -2072,6 +2074,7 @@ extern "C" {
GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

View file

@ -1936,6 +1936,28 @@ static void llama_kv_cache_seq_shift(
cache.head = new_head != cache.size ? new_head : 0;
}
static void llama_kv_cache_seq_div(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
{
llama_pos p_old = cache.cells[i].pos;
cache.cells[i].pos /= d;
cache.cells[i].delta += cache.cells[i].pos - p_old;
}
}
}
}
//
// model loading and saving
//
@ -2233,6 +2255,7 @@ struct llama_model_loader {
case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
default:
{
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -2577,6 +2600,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw";
default: return "unknown, may not work";
}
@ -8490,6 +8514,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS:quantized_type = GGML_TYPE_IQ2_XXS; break;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
}
@ -9622,9 +9647,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
}
void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (delta == 0) {
return;
}
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
}
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) {
return;
}
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
}
// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.

12
llama.h
View file

@ -103,6 +103,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
@ -495,6 +496,17 @@ extern "C" {
llama_pos p1,
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
//
// State / sessions
//

70
scripts/get-pg.sh Executable file
View file

@ -0,0 +1,70 @@
#!/bin/bash
function usage {
echo "usage: <n>$0"
echo "note: n is the number of essays to download"
echo "for specific n, the resulting pg.txt file will have the following number of tokens:"
echo "n | tokens"
echo "--- | ---"
echo "1 | 6230"
echo "2 | 23619"
echo "5 | 25859"
echo "10 | 36888"
echo "15 | 50188"
echo "20 | 59094"
echo "25 | 88764"
echo "30 | 103121"
echo "32 | 108338"
echo "35 | 113403"
echo "40 | 127699"
echo "45 | 135896"
exit 1
}
function has_cmd {
if ! [ -x "$(command -v $1)" ]; then
echo "error: $1 is not available" >&2
exit 1
fi
}
# check for: curl, html2text, tail, sed, fmt
has_cmd curl
has_cmd html2text
has_cmd tail
has_cmd sed
if [ $# -ne 1 ]; then
usage
fi
n=$1
# get urls
urls="$(curl http://www.aaronsw.com/2002/feeds/pgessays.rss | grep html | sed -e "s/.*http/http/" | sed -e "s/html.*/html/" | head -n $n)"
printf "urls:\n%s\n" "$urls"
if [ -f pg.txt ]; then
rm pg.txt
fi
c=1
for url in $urls; do
echo "processing $url"
cc=$(printf "%03d" $c)
curl -L $url | html2text | tail -n +4 | sed -E "s/^[[:space:]]+//g" | fmt -w 80 >> pg-$cc-one.txt
cat pg-$cc-one.txt >> pg.txt
cp -v pg.txt pg-$cc-all.txt
c=$((c+1))
# don't flood the server
sleep 1
done
echo "done. data in pg.txt"
exit 0

View file

@ -455,7 +455,7 @@ struct test_case {
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
//for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}
@ -1463,6 +1463,7 @@ struct test_moe : public test_case {
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
@ -1597,7 +1598,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
test_cases.emplace_back(new test_soft_max());
std::uniform_int_distribution<> dist_ne1(1, 50);
int exponent = 1;
while (exponent < (1 << 17)) {
std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng);
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
}
exponent <<= 1;
}
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B

View file

@ -134,6 +134,11 @@ int main(int argc, char * argv[]) {
continue;
}
if ((ggml_type)i == GGML_TYPE_IQ2_XXS) {
printf("Skip %s due to missing quantization functionality\n", ggml_type_name((ggml_type) i));
continue;
}
printf("Testing %s\n", ggml_type_name((ggml_type) i));
if (qfns.from_float && qfns.to_float) {