From 3e945cc1e9c06d2001031360e4e303e9548fb02c Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 18 Jan 2024 19:18:21 +0200 Subject: [PATCH 01/28] HellaSwag: speed up by parallelizing log-prob evaluation (#5020) For Mistral-7B and fp16, time on my system goes down from 536 seconds to 423 seconds for the full evaluation dataset (10042 tasks). Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 80 ++++++++++++++++++++++++------ 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index ea2c8026c..9498dd535 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -444,6 +445,48 @@ static std::vector evaluate_tokens(llama_context * ctx, std::vector return result; } +static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector& workers, + const std::vector>& eval_pairs, std::vector& eval_results) { + constexpr int k_token_chunk = 4; + if (eval_results.size() != eval_pairs.size()) { + eval_results.resize(eval_pairs.size()); + } + if (eval_pairs.empty()) return; + + size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size()); + + std::atomic counter(0); + auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () { + float local_logprobs[k_token_chunk]; + while (true) { + size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed); + if (first >= eval_results.size()) break; + size_t last = std::min(first + k_token_chunk, eval_results.size()); + for (size_t i = first; i < last; ++i) { + auto logits = batch_logits + eval_pairs[i].first * n_vocab; + float max_logit = logits[0]; + for (int j = 1; j < n_vocab; ++j) { + max_logit = std::max(max_logit, logits[j]); + } + float sum_p = 0.f; + for (int j = 0; j < n_vocab; ++j) { + sum_p += expf(logits[j] - max_logit); + } + local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p); + } + std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float)); + } + }; + + for (size_t it = 0; it < max_threads; ++it) { + workers[it] = std::thread(compute); + } + for (size_t it = 0; it < max_threads; ++it) { + workers[it].join(); + } + +} + static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Calculates hellaswag score (acc_norm) from prompt // @@ -574,6 +617,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { std::vector tok_logits(n_vocab); std::vector batch_logits(n_ctx*n_vocab); + std::vector> eval_pairs; + std::vector eval_results; + std::vector workers(std::thread::hardware_concurrency()); + auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) { for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); @@ -654,6 +701,24 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } + // Compute log-probs in parallel + // First we collect all tasks + eval_pairs.clear(); + for (size_t i = i0; i < i1; ++i) { + auto & hs_cur = hs_data[i]; + size_t li = hs_cur.common_prefix; + for (int s = 0; s < 4; ++s) { + for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { + eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1])); + } + ++li; + } + } + // Then we do the actual calculation + hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); + + size_t ir = 0; + // compute the logprobs for each ending of the decoded tasks for (size_t i = i0; i < i1; ++i) { auto & hs_cur = hs_data[i]; @@ -662,26 +727,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const auto first_probs = softmax(tok_logits); - size_t li = hs_cur.common_prefix; // logits index in the batch - for (int s = 0; s < 4; ++s) { hs_cur.ending_logprob_count[s] = 1; hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]); - - // Calculate the logprobs over the ending for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float)); - - const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]]; - - hs_cur.ending_logprob[s] += std::log(prob); + hs_cur.ending_logprob[s] += eval_results[ir++]; hs_cur.ending_logprob_count[s]++; } - - // account that we skip the last token in the ending - ++li; - - // Calculate the mean token logprob for acc_norm hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s]; } From b46757735d30f5c6ed4f20ebeccc684e02d4f3bf Mon Sep 17 00:00:00 2001 From: David Sommers <12738+databyte@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:20:59 -0500 Subject: [PATCH 02/28] convert.py : fix llama/llama2 conversion due to vocab_size=-1 (#5019) PR #4818 (merged last week) reintroduced a config check for vocab_size that was addressed in PR #4258 (merged 2023-11-30). Without the fix, llama2 models can't be converted. The error is: `ValueError: The model's vocab size is set to -1 in params.json. Please update it manually. Maybe 32000?` --- convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert.py b/convert.py index e38ee5315..980e6fc72 100755 --- a/convert.py +++ b/convert.py @@ -348,7 +348,7 @@ class Params: f_rope_freq_base = 1e6 return Params( - n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), + n_vocab=model["tok_embeddings.weight"].shape[0], n_embd=config["dim"], n_layer=config["n_layers"], n_ctx=n_ctx, From e9240cdfa06a50c1b5dbafa367cb8cd698e65103 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 20:45:39 +0200 Subject: [PATCH 03/28] scripts : add get-winogrande.sh --- scripts/get-hellaswag.sh | 2 +- scripts/get-wikitext-2.sh | 7 +++++++ scripts/get-winogrande.sh | 10 ++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100755 scripts/get-winogrande.sh diff --git a/scripts/get-hellaswag.sh b/scripts/get-hellaswag.sh index ef8dcceb0..121979fe2 100755 --- a/scripts/get-hellaswag.sh +++ b/scripts/get-hellaswag.sh @@ -4,7 +4,7 @@ wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag echo "Usage:" echo "" -echo " ./perplexity --hellaswag --hellaswag-tasks N -f hellaswag_val_full.txt -m modelfile.gguf" +echo " ./perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag [--hellaswag-tasks N] [other params]" echo "" exit 0 diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index 98aec3e3e..ff96f331e 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,3 +1,10 @@ #!/bin/bash wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip + +echo "Usage:" +echo "" +echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]" +echo "" + +exit 0 diff --git a/scripts/get-winogrande.sh b/scripts/get-winogrande.sh new file mode 100755 index 000000000..5f234468e --- /dev/null +++ b/scripts/get-winogrande.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv + +echo "Usage:" +echo "" +echo " ./perplexity -m model.gguf -f winogrande-debiased-eval.csv --winogrande [--winogrande-tasks N] [other params]" +echo "" + +exit 0 From d391ae9b4919e24624cc963d82162450848beaf4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 20:49:00 +0200 Subject: [PATCH 04/28] perplexity : fix winogrande N tasks option --- examples/perplexity/perplexity.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9498dd535..f72ea6d1c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -865,7 +865,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { } float scale = 1/(1.f + (float)rng.max()); std::vector selected; - selected.reserve(params.winogrande_tasks); + selected.resize(params.winogrande_tasks); for (int i = 0; i < int(params.winogrande_tasks); ++i) { int j = int(scale*rng()*aux.size()); selected[i] = std::move(data[aux[j]]); From 2d5419d08ab1131623e6a1d554607b7663435e87 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 21:45:51 +0200 Subject: [PATCH 05/28] imatrix : fix assert for src0 non-cont check --- examples/imatrix/imatrix.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index af78711c5..5a3d30b88 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -80,7 +80,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // for simplicity, always copy src0 to host, because it is small // take into account that src0 is not contiguous! GGML_ASSERT(src0->ne[1] == src1->ne[1]); - GGML_ASSERT(n_as*ggml_nrows(src0)); + GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int))); m_ids.resize(ggml_nbytes(src0)/sizeof(int)); ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0)); From 96d7f56d2918ffde1995dbb32392571deb76d7fc Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 18 Jan 2024 21:12:15 +0100 Subject: [PATCH 06/28] llama : fix mlock with no-mmap with Metal (#5025) --- llama.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index d28382f7d..f1d00a96c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1599,7 +1599,7 @@ struct llama_model { std::unique_ptr mapping; // objects representing data potentially being locked in memory - llama_mlock mlock_buf; + std::vector> mlock_bufs; llama_mlock mlock_mmap; // for quantize-stats only @@ -3815,8 +3815,10 @@ static bool llm_load_tensors( else { buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf != nullptr && use_mlock && ggml_backend_buffer_is_host(buf)) { - model.mlock_buf.init (ggml_backend_buffer_get_base(buf)); - model.mlock_buf.grow_to(ggml_backend_buffer_get_size(buf)); + model.mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = model.mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } } if (buf == nullptr) { From 821f0a271e7c9ee737945245dd7abfa22cc9b5b0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 18 Jan 2024 21:33:05 +0100 Subject: [PATCH 07/28] server : defer tasks when "slot unavailable" (#5018) * server: defer task when no slot is available * remove unnecessary log --------- Co-authored-by: Xuan Son Nguyen --- examples/server/server.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 93f999298..0462fbd24 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1558,6 +1558,7 @@ struct llama_server_context void process_tasks() { std::unique_lock lock(mutex_tasks); + std::vector deferred_tasks; while (!queue_tasks.empty()) { task_server task = queue_tasks.front(); @@ -1568,9 +1569,8 @@ struct llama_server_context llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); if (slot == nullptr) { - LOG_TEE("slot unavailable\n"); - // send error result - send_error(task, "slot unavailable"); + // if no slot is available, we defer this task for processing later + deferred_tasks.push_back(task); break; } @@ -1616,6 +1616,12 @@ struct llama_server_context } } + // add all the deferred tasks back the the queue + for (task_server &task : deferred_tasks) + { + queue_tasks.push_back(task); + } + // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue std::vector agg_results; auto queue_iterator = queue_multitasks.begin(); From 9b6ea4263ab45e02ff905bf7a29dc143ca1facc3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 23:36:07 +0200 Subject: [PATCH 08/28] cmake : add ggml public headers (#5011) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7bd640966..3fc65eaf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -846,7 +846,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) -set(GGML_PUBLIC_HEADERS "ggml.h" +set(GGML_PUBLIC_HEADERS "ggml.h" "ggml-alloc.h" "ggml-backend.h" "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}" "${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}") From 57e2a7a52a819883f40dada8a2edc24ecf48186b Mon Sep 17 00:00:00 2001 From: John <78893154+cmp-nct@users.noreply.github.com> Date: Thu, 18 Jan 2024 23:12:15 +0100 Subject: [PATCH 09/28] llama : fix falcon arch for tied output embeddings (#4978) * falcon arch fix for tied output embeddings * Update llama.cpp Co-authored-by: Georgi Gerganov * Update llama.cpp * Update llama.cpp Co-authored-by: Georgi Gerganov * Update llama.cpp --------- Co-authored-by: Georgi Gerganov --- llama.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index f1d00a96c..47b4384a8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3438,7 +3438,12 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_OUTPUT, "weight").c_str()) >= 0) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } else { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU + ml.n_created--; // artificial tensor + } } for (int i = 0; i < n_layer; ++i) { From 8b20858e5e9c44b99b4b31ae9c40b8f20d01d94f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 10:45:06 +0200 Subject: [PATCH 10/28] perplexity : faster Winogrande via batching (#5024) * perplexity : faster Winogrande via batching ggml-ci * perplexity : remove unused function * perplexity : only tokenize selected tasks for Winogrande --- examples/perplexity/perplexity.cpp | 287 ++++++++++++++++------------- 1 file changed, 160 insertions(+), 127 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f72ea6d1c..df902fb1c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, ppl, logit_history, prob_history}; } -static std::vector evaluate_tokens(llama_context * ctx, std::vector & tokens, - int n_past, int n_batch, int n_vocab) { - std::vector result; - result.reserve(tokens.size() * n_vocab); - size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; - for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { - size_t n_tokens = tokens.size() - i_chunk * n_batch; - n_tokens = std::min(n_tokens, size_t(n_batch)); - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return {}; +static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int32_t n_batch, int32_t n_vocab) { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + return false; } - const auto logits = llama_get_logits(ctx); - result.insert(result.end(), logits, logits + n_tokens * n_vocab); - - n_past += n_tokens; + memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); } - return result; + + return true; } static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector& workers, @@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // determine the common prefix of the endings hs_cur.common_prefix = 0; - hs_cur.required_tokens = 0; for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) { if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] || hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] || @@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int max_tasks_per_batch = params.n_parallel; + const int max_tasks_per_batch = 32; const int max_seq = 4*max_tasks_per_batch; llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); - std::vector batch_logits(n_ctx*n_vocab); + std::vector batch_logits(n_vocab*n_ctx); std::vector> eval_pairs; std::vector eval_results; std::vector workers(std::thread::hardware_concurrency()); - auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) { - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; - - const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); - return false; - } - - memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); - } - - return true; - }; - for (size_t i0 = 0; i0 < hs_task_count; i0++) { int n_cur = 0; @@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, n_batch)) { + if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; } @@ -772,6 +749,13 @@ struct winogrande_entry { std::string second; std::array choices; int answer; + + size_t i_batch; + size_t common_prefix; + size_t required_tokens; + size_t n_base1; // number of tokens for context + choice 1 + size_t n_base2; // number of tokens for context + choice 2 + std::vector seq_tokens[2]; }; static std::vector load_winogrande_from_csv(const std::string& prompt) { @@ -875,115 +859,164 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { data = std::move(selected); } + fprintf(stderr, "%s : tokenizing selected tasks\n", __func__); + // This is needed as usual for LLaMA models const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + for (auto & task : data) { + task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos); + task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos); + + task.common_prefix = 0; + for (size_t k = 0; k < task.seq_tokens[0].size(); k++) { + if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) { + break; + } + task.common_prefix++; + } + + task.required_tokens = task.common_prefix + + task.seq_tokens[0].size() - task.common_prefix + + task.seq_tokens[1].size() - task.common_prefix; + + task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size(); + task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size(); + } + fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx); + const int n_batch = params.n_batch; + + const int max_tasks_per_batch = 128; + const int max_seq = 2*max_tasks_per_batch; + + llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); + std::vector batch_logits(n_vocab*n_ctx); int n_correct = 0; int n_done = 0; - for (size_t task_idx = 0; task_idx < data.size(); task_idx++) { - const auto& task = data[task_idx]; + for (size_t i0 = 0; i0 < data.size(); i0++) { + int n_cur = 0; - auto base_context = ::llama_tokenize(ctx, task.first, add_bos); - auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos); - auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos); + size_t i1 = i0; + size_t i_batch = 0; - auto sentence_1st = task.first + task.choices[0] + task.second; - auto sentence_2nd = task.first + task.choices[1] + task.second; - auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos); - auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos); + llama_batch_clear(batch); - if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) { - fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size()); + while (n_cur + (int) data[i1].required_tokens <= n_ctx) { + const int s0 = 2*(i1 - i0); + if (s0 + 2 > max_seq) { + break; + } + + for (size_t i = 0; i < data[i1].common_prefix; ++i) { + llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false); + } + batch.logits[batch.n_tokens - 1] = true; + + for (int s = 0; s < 2; ++s) { + for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { + llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + } + } + + data[i1].i_batch = i_batch; + i_batch += data[i1].required_tokens; + + n_cur += data[i1].required_tokens; + if (++i1 == data.size()) { + break; + } + } + + if (i0 == i1) { + fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); return; } - auto query_1st_size = query_1st.size(); - auto query_2nd_size = query_2nd.size(); - - // Speedup small evaluations by evaluating atleast 32 tokens - // For Winogrande this seems to slow it down rather than speed it up. - //if (query_1st.size() < 32) query_1st.resize(32); - //if (query_2nd.size() < 32) query_2nd.resize(32); - llama_kv_cache_clear(ctx); - auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab); - llama_kv_cache_clear(ctx); - auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab); - - if (logits_1st.empty() || logits_2nd.empty()) { - fprintf(stderr, "%s : failed to eval\n", __func__); + // decode all tasks [i0, i1) + if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; } - bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx && - query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx; + for (size_t i = i0; i < i1; ++i) { + auto & task = data[i]; - float score_1st = 0; - bool is_nan_1st = false; - const auto& base_1 = skip_choice ? base_ctx_1st : base_context; - const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0; - for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) { - std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[query_1st[j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, sentence_1st.c_str(), base_context.size()); - is_nan_1st = true; - break; + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + + float score_1st = 0; + bool is_nan_1st = false; + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + size_t li = n_base1 - 1; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float)); + const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]]; + if (std::isnan(prob) || !prob) { + fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, + prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1); + is_nan_1st = true; + break; + } + score_1st += std::log(prob); } - score_1st += std::log(prob); - } - score_1st /= (query_1st_size - base_1.size() - last_1st); + score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st); - float score_2nd = 0; - bool is_nan_2nd = false; - const auto& base_2 = skip_choice ? base_ctx_2nd : base_context; - const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0; - for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) { - std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[query_2nd[j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, sentence_2nd.c_str(), base_context.size()); - is_nan_2nd = true; - break; + float score_2nd = 0; + bool is_nan_2nd = false; + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float)); + const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]]; + if (std::isnan(prob) || !prob) { + fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, + prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2); + is_nan_2nd = true; + break; + } + score_2nd += std::log(prob); } - score_2nd += std::log(prob); - } - score_2nd /= (query_2nd_size - base_2.size() - last_2nd); + score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd); - if (is_nan_1st || is_nan_2nd) { - continue; + if (is_nan_1st || is_nan_2nd) { + continue; + } + + if (std::isnan(score_1st) || std::isnan(score_2nd)) { + printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd); + printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size()); + printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size()); + printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix); + printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice); + continue; + } + + int result = score_1st > score_2nd ? 1 : 2; + + if (result == task.answer) { + ++n_correct; + } + ++n_done; + + // Print the accumulated accuracy mean x 100 + printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer); + fflush(stdout); } - if (std::isnan(score_1st) || std::isnan(score_2nd)) { - printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd); - printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size); - printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size); - printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size()); - printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice); - continue; - } - - int result = score_1st > score_2nd ? 1 : 2; - - if (result == task.answer) { - ++n_correct; - } - ++n_done; - - // Print the accumulated accuracy mean x 100 - printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer); - fflush(stdout); + i0 = i1 - 1; } printf("\n"); From 993fba81807e55d27b570945af8e416d535eced1 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:02:39 +0200 Subject: [PATCH 11/28] perplexity: avoid unnecessary alloocations and logit copies (#5035) Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index df902fb1c..292502f87 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par double nll = 0.0; double nll2 = 0.0; + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + if (num_batches > 1) { + logits.reserve((size_t)n_ctx * n_vocab); + } + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); std::vector workers(std::thread::hardware_concurrency() - 1); @@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int start = i * n_ctx; const int end = start + n_ctx; - const int num_batches = (n_ctx + n_batch - 1) / n_batch; - - std::vector logits; - const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache @@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + if (num_batches > 1) { + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + } } const auto t_end = std::chrono::high_resolution_clock::now(); @@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. const int first = n_ctx/2; - process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); count += n_ctx - first - 1; @@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); + + logits.clear(); } printf("\n"); From 2b3b999cacc7ad1207c32fbdf3479a19c06e1a34 Mon Sep 17 00:00:00 2001 From: chiranko <96988916+chiranko@users.noreply.github.com> Date: Fri, 19 Jan 2024 17:07:27 +0800 Subject: [PATCH 12/28] llama : add CodeShell support (#5016) * llama: add codeshell support * llama.cpp: fix codeshell with NeoX rope Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf.py | 67 ++++++++++++ gguf-py/gguf/constants.py | 19 ++++ gguf-py/gguf/tensor_mapping.py | 1 + llama.cpp | 181 +++++++++++++++++++++++++++++++++ 4 files changed, 268 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1178d63a2..aae3a5e87 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -197,6 +197,8 @@ class Model: return Phi2Model if model_architecture == "PlamoForCausalLM": return PlamoModel + if model_architecture == "CodeShellForCausalLM": + return CodeShellModel return Model def _is_model_safetensors(self) -> bool: @@ -242,6 +244,8 @@ class Model: return gguf.MODEL_ARCH.PHI2 if arch == "PlamoForCausalLM": return gguf.MODEL_ARCH.PLAMO + if arch == "CodeShellForCausalLM": + return gguf.MODEL_ARCH.CODESHELL raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1175,6 +1179,69 @@ class PlamoModel(Model): self.gguf_writer.add_tensor(new_name, data) +class CodeShellModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("CodeShell") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(10000.0) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + tensors = dict(self.get_tensors()) + has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys() + for name, data_torch in tensors.items(): + # we don't need these + if name.endswith((".attn.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "transformer.wte.weight": + self.gguf_writer.add_tensor("output.weight", data) + print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 972b4e9a7..95c58b419 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -99,6 +99,7 @@ class MODEL_ARCH(IntEnum): QWEN = auto() PHI2 = auto() PLAMO = auto() + CODESHELL = auto() class MODEL_TENSOR(IntEnum): @@ -147,6 +148,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -396,6 +398,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, ] # TODO } @@ -417,6 +432,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5b146106..de177af13 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -154,6 +154,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo + "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell ), # Feed-forward norm diff --git a/llama.cpp b/llama.cpp index 47b4384a8..1cee5a791 100644 --- a/llama.cpp +++ b/llama.cpp @@ -194,6 +194,7 @@ enum llm_arch { LLM_ARCH_QWEN, LLM_ARCH_PHI2, LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, LLM_ARCH_UNKNOWN, }; @@ -213,6 +214,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, }; enum llm_kv { @@ -600,6 +602,26 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_CODESHELL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, @@ -2877,6 +2899,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CODESHELL: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3784,6 +3814,42 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_CODESHELL: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5965,6 +6031,117 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_codeshell() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(tmpq, "tmpq", il); + cb(tmpk, "tmpk", il); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + // add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph( @@ -6159,6 +6336,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_gpt2(); } break; + case LLM_ARCH_CODESHELL: + { + result = llm.build_codeshell(); + } break; default: GGML_ASSERT(false); } From 7051aacfac0057fa5fac9ea46c55bffc3892d810 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:39:11 +0200 Subject: [PATCH 13/28] winogrande: evaluate log-probs in parallel (#5036) This is a relatively minor performance tweak resulting in ~10% speedup on my system. Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 71 ++++++++++++++---------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 292502f87..b07320190 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -458,7 +458,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< return true; } -static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector& workers, +static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector& workers, const std::vector>& eval_pairs, std::vector& eval_results) { constexpr int k_token_chunk = 4; if (eval_results.size() != eval_pairs.size()) { @@ -700,7 +700,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } } // Then we do the actual calculation - hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); + compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); size_t ir = 0; @@ -906,6 +906,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { std::vector tok_logits(n_vocab); std::vector batch_logits(n_vocab*n_ctx); + std::vector> eval_pairs; + std::vector eval_results; + std::vector workers(std::thread::hardware_concurrency()); + int n_correct = 0; int n_done = 0; @@ -956,6 +960,30 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { return; } + eval_pairs.clear(); + for (size_t i = i0; i < i1; ++i) { + auto & task = data[i]; + + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + size_t li = n_base1 - 1; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { + eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1])); + } + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { + eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1])); + } + } + compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); + + size_t ir = 0; for (size_t i = i0; i < i1; ++i) { auto & task = data[i]; @@ -964,54 +992,21 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; float score_1st = 0; - bool is_nan_1st = false; const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; - size_t li = n_base1 - 1; for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1); - is_nan_1st = true; - break; - } - score_1st += std::log(prob); + score_1st += eval_results[ir++]; } score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st); float score_2nd = 0; - bool is_nan_2nd = false; const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; - li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1; for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2); - is_nan_2nd = true; - break; - } - score_2nd += std::log(prob); + score_2nd += eval_results[ir++]; } score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd); - if (is_nan_1st || is_nan_2nd) { - continue; - } - - if (std::isnan(score_1st) || std::isnan(score_2nd)) { - printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd); - printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size()); - printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size()); - printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix); - printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice); - continue; - } - int result = score_1st > score_2nd ? 1 : 2; if (result == task.answer) { @@ -1019,7 +1014,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { } ++n_done; - // Print the accumulated accuracy mean x 100 + // print the accumulated accuracy mean x 100 printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer); fflush(stdout); } From de9a147df14e62f54f879d2d15e6c4793107f4fc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 13:52:22 +0200 Subject: [PATCH 14/28] py : fix flake8 lint --- convert-hf-to-gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index aae3a5e87..d2d6948d8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1179,6 +1179,7 @@ class PlamoModel(Model): self.gguf_writer.add_tensor(new_name, data) + class CodeShellModel(Model): def set_gguf_parameters(self): block_count = self.hparams["n_layer"] From 9b75cb2b3ccbed3df2e14c1202168db3e5145095 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Fri, 19 Jan 2024 19:53:13 +0800 Subject: [PATCH 15/28] llama : support upcoming Qwen2 (#5037) --- convert-hf-to-gguf.py | 4 + gguf-py/gguf/constants.py | 16 ++++ llama.cpp | 191 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 211 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d2d6948d8..5cb3e63fb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -189,6 +189,8 @@ class Model: return StableLMModel if model_architecture == "QWenLMHeadModel": return QwenModel + if model_architecture == "Qwen2ForCausalLM": + return Model if model_architecture == "MixtralForCausalLM": return MixtralModel if model_architecture == "GPT2LMHeadModel": @@ -236,6 +238,8 @@ class Model: return gguf.MODEL_ARCH.STABLELM if arch == "QWenLMHeadModel": return gguf.MODEL_ARCH.QWEN + if arch == "Qwen2ForCausalLM": + return gguf.MODEL_ARCH.QWEN2 if arch == "MixtralForCausalLM": return gguf.MODEL_ARCH.LLAMA if arch == "GPT2LMHeadModel": diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 95c58b419..2d9c33c7d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -97,6 +97,7 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() + QWEN2 = auto() PHI2 = auto() PLAMO = auto() CODESHELL = auto() @@ -146,6 +147,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.CODESHELL: "codeshell", @@ -358,6 +360,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/llama.cpp b/llama.cpp index 1cee5a791..90579ac85 100644 --- a/llama.cpp +++ b/llama.cpp @@ -192,6 +192,7 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, LLM_ARCH_PHI2, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, @@ -212,6 +213,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, @@ -568,6 +570,23 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_PHI2, { @@ -2869,6 +2888,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3704,6 +3734,41 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); } } break; + case LLM_ARCH_QWEN2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; case LLM_ARCH_PHI2: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5698,6 +5763,128 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_qwen2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, Qcur); + ggml_build_forward_expand(gf, Kcur); + ggml_build_forward_expand(gf, Vcur); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6324,6 +6511,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; + case LLM_ARCH_QWEN2: + { + result = llm.build_qwen2(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); From a5cacb22b2114fd9adf61c00cbb237384d86bced Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 15:24:47 +0200 Subject: [PATCH 16/28] imatrix : add README.md --- examples/imatrix/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 examples/imatrix/README.md diff --git a/examples/imatrix/README.md b/examples/imatrix/README.md new file mode 100644 index 000000000..578e8fc27 --- /dev/null +++ b/examples/imatrix/README.md @@ -0,0 +1,32 @@ +# llama.cpp/examples/imatrix + +Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantum models. +More information is available here: https://github.com/ggerganov/llama.cpp/pull/4861 + +## Usage + +``` +./imatrix -m -f [-o ] [--verbosity ] + [-ofreq num_chunks] [-ow <0 or 1>] [other common params] +``` + +Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory. +The parameters in square brackets are optional and have the following meaning: +* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used. +* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`. +* `-ofreq` (or `--output-frequency`) specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) +* `-ow` (or `--output-weight`) specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default. + +For faster computation, make sure to use GPU offloading via the `-ngl` argument + +## Example + +```bash +LLAMA_CUBLAS=1 make -j + +# generate importance matrix (imatrix.dat) +./imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99 + +# use the imatrix to perform a Q4_K_M quantization +./quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m +``` From 381ee195721d8e747ee31a60c0751822b3072f02 Mon Sep 17 00:00:00 2001 From: Uzo Nweke Date: Fri, 19 Jan 2024 13:20:50 -0500 Subject: [PATCH 17/28] finetune : fix ggml_allocr lifetimes (tmp workaround) (#5033) * Fix issue with alloc causing max_compute_size to be calculated * remove ggml_allocr_free as suggested in issue #4791 --- examples/train-text-from-scratch/train-text-from-scratch.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 4a9a2340b..eee9d4de3 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -263,7 +263,6 @@ static void init_model(struct my_llama_model * model) { model->data.resize(size + tensor_alignment); alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment); alloc_model(alloc, model); - ggml_allocr_free(alloc); } static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) { @@ -1102,7 +1101,6 @@ int main(int argc, char ** argv) { alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); ggml_allocr_alloc(alloc, tokens_input); ggml_allocr_alloc(alloc, target_probs); - ggml_allocr_free(alloc); // context for compute tensors without their data const size_t estimated_compute_size_wo_data = ( @@ -1149,7 +1147,6 @@ int main(int argc, char ** argv) { best_compute_size = max_compute_size; best_order = gf->order; } - ggml_allocr_free(alloc); ggml_free(ctx_compute); } size_t max_compute_size = best_compute_size; @@ -1177,7 +1174,6 @@ int main(int argc, char ** argv) { params.common.use_flash, params.common.use_checkpointing ); - ggml_allocr_free(alloc); std::vector train_tokens; std::vector train_samples_begin; From cca894f16a5eade15afd07b015e4cb3d8658943f Mon Sep 17 00:00:00 2001 From: Kylin <56434533+KyL0N@users.noreply.github.com> Date: Sat, 20 Jan 2024 15:01:46 +0800 Subject: [PATCH 18/28] cuda : fix compile error in jetson platform (#4975) * cuda: fix compile error in jetson platform * cuda: update comment in ggml-cuda.cu * cuda: update ggml-cuda.cu comment --- ggml-cuda.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b2211d858..ec3837fb8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -12,9 +12,6 @@ #include #include #include -#include "ggml-cuda.h" -#include "ggml.h" -#include "ggml-backend-impl.h" #if defined(GGML_USE_HIPBLAS) #include @@ -118,6 +115,11 @@ #endif // defined(GGML_USE_HIPBLAS) +// ggml-cuda need half type so keep ggml headers include at last +#include "ggml-cuda.h" +#include "ggml.h" +#include "ggml-backend-impl.h" + #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CC_PASCAL 600 From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 19/28] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 263 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 136 +++++++++++-------- tests/test-backend-ops.cpp | 14 +- 6 files changed, 231 insertions(+), 198 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36..4d85dd3dd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794c..a1e1755a3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce..e64a328fa 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } - - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - assert(sum > 0.0); + float s; - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } -#endif + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c..7bca02f2a 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef..4e6c9f9cc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,7 +4356,60 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); + + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + // split cached v into n_head heads (transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, @@ -4352,59 +4418,15 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; - - struct ggml_tensor * kqv; - - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197..a56c0d6c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 Subject: [PATCH 20/28] metal : initial implementation --- ggml-metal.m | 69 +++++++++++++------ ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 180 insertions(+), 33 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3dd..556c53482 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 4; + const int nwarps = 1; - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3..5986bcb42 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fa..10df03c9c 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c5..51a33c662 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 21/28] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482..e67a7c4ef 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb42..e4e89b5b3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void kernel_flash_attn_ext_f16( } const int64_t D = ne00; + const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 22/28] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3..f3a7efafa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 = 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; + M = max(M, s); - if (s > M) { - M = s; - ms = exp(Mold - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } - - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 23/28] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef..046643146 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa..d97952f2b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? @@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 24/28] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 175 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 140 insertions(+), 45 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146..0b1119c4e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b..789b19bad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - if (iq1 >= ne01) { + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; + + if (iq2 >= ne02) { return; } - const int64_t D4 = D/4; - - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); - - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; - } - - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; - const half Mold = M; + simdgroup_barrier(mem_flags::mem_threadgroup); - M = max(M, s); + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; - const half ms = exp(Mold - M); - const half vs = exp(s - M); + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + const half Mold = M; + + M = max(M, s); + + const half ms = exp(Mold - M); + const half vs = exp(s - M); + + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 25/28] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4e..abb96d6ec 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad..6fdd7fdad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2140,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 26/28] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec..d521df43a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad..c9876c103 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 27/28] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43a..a60dd779a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c103..539e26c91 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 28/28] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a..fdfb50d3d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91..919119c8d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 0.0h; +#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0,