Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cpp into flash-attn-cuda

commit a689b02ad3
FSSRepo, 2024-01-23 13:51:59 -05:00

20 changed files with 1437 additions and 512 deletions


@ -846,7 +846,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
set(GGML_PUBLIC_HEADERS "ggml.h" set(GGML_PUBLIC_HEADERS "ggml.h" "ggml-alloc.h" "ggml-backend.h"
"${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}" "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}"
"${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}") "${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}")


@ -189,6 +189,8 @@ class Model:
return StableLMModel return StableLMModel
if model_architecture == "QWenLMHeadModel": if model_architecture == "QWenLMHeadModel":
return QwenModel return QwenModel
if model_architecture == "Qwen2ForCausalLM":
return Model
if model_architecture == "MixtralForCausalLM": if model_architecture == "MixtralForCausalLM":
return MixtralModel return MixtralModel
if model_architecture == "GPT2LMHeadModel": if model_architecture == "GPT2LMHeadModel":
@ -197,6 +199,8 @@ class Model:
return Phi2Model return Phi2Model
if model_architecture == "PlamoForCausalLM": if model_architecture == "PlamoForCausalLM":
return PlamoModel return PlamoModel
if model_architecture == "CodeShellForCausalLM":
return CodeShellModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool:
@ -234,6 +238,8 @@ class Model:
return gguf.MODEL_ARCH.STABLELM return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel": if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN return gguf.MODEL_ARCH.QWEN
if arch == "Qwen2ForCausalLM":
return gguf.MODEL_ARCH.QWEN2
if arch == "MixtralForCausalLM": if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA return gguf.MODEL_ARCH.LLAMA
if arch == "GPT2LMHeadModel": if arch == "GPT2LMHeadModel":
@ -242,6 +248,8 @@ class Model:
return gguf.MODEL_ARCH.PHI2 return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM": if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO return gguf.MODEL_ARCH.PLAMO
if arch == "CodeShellForCausalLM":
return gguf.MODEL_ARCH.CODESHELL
raise NotImplementedError(f'Architecture "{arch}" not supported!') raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -1176,6 +1184,70 @@ class PlamoModel(Model):
self.gguf_writer.add_tensor(new_name, data) self.gguf_writer.add_tensor(new_name, data)
class CodeShellModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_name("CodeShell")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(10000.0)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
tensors = dict(self.get_tensors())
has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
for name, data_torch in tensors.items():
# we don't need these
if name.endswith((".attn.rotary_emb.inv_freq")):
continue
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
if not has_lm_head and name == "transformer.wte.weight":
self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######
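With the Qwen2 and CodeShell mappings added above, converting such a checkpoint should work like any other architecture supported by the script. A rough sketch (the paths are placeholders, and the `--outfile`/`--outtype` options are assumed to match the script's usual interface):

```bash
# hypothetical local Hugging Face checkout of the model
python3 convert-hf-to-gguf.py ./CodeShell-7B --outtype f16 --outfile codeshell-7b-f16.gguf
```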


@ -348,7 +348,7 @@ class Params:
f_rope_freq_base = 1e6 f_rope_freq_base = 1e6
return Params( return Params(
n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), n_vocab=model["tok_embeddings.weight"].shape[0],
n_embd=config["dim"], n_embd=config["dim"],
n_layer=config["n_layers"], n_layer=config["n_layers"],
n_ctx=n_ctx, n_ctx=n_ctx,


@ -0,0 +1,32 @@
# llama.cpp/examples/imatrix
Compute an importance matrix for a model and a given text dataset. It can be used during quantization to enhance the quality of the quantized models.
More information is available here: https://github.com/ggerganov/llama.cpp/pull/4861
## Usage
```
./imatrix -m <some_fp_model> -f <some_training_data> [-o <output_file>] [--verbosity <verbosity_level>]
[-ofreq num_chunks] [-ow <0 or 1>] [other common params]
```
Here `-m` with a model name and `-f` with a file containing training data (e.g. `wiki.train.raw`) are mandatory.
The parameters in square brackets are optional and have the following meaning:
* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing, `imatrix.dat` is used.
* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, a message is written to `stderr` each time the results are saved. If `>=2`, a message is output each time data is collected for any tensor. The default verbosity level is `1`.
* `-ofreq` (or `--output-frequency`) specifies how often the results computed so far are saved to disk. The default is 10 (i.e., every 10 chunks).
* `-ow` (or `--output-weight`) specifies whether data will be collected for the `output.weight` tensor. My experience is that it is better not to use the importance matrix when quantizing `output.weight`, so this is set to `false` by default.
For faster computation, make sure to use GPU offloading via the `-ngl` argument.
## Example
```bash
LLAMA_CUBLAS=1 make -j
# generate importance matrix (imatrix.dat)
./imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99
# use the imatrix to perform a Q4_K_M quantization
./quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m
```
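
For completeness, a run that exercises the optional flags documented above might look like this (model and data file names are placeholders):

```bash
# save every 20 chunks to a custom file, log more verbosely,
# and also collect data for output.weight
./imatrix -m ggml-model-f16.gguf -f wiki.train.raw -o my-imatrix.dat \
    --output-frequency 20 --verbosity 2 -ow 1 -ngl 99
```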


@ -80,7 +80,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
// for simplicity, always copy src0 to host, because it is small // for simplicity, always copy src0 to host, because it is small
// take into account that src0 is not contiguous! // take into account that src0 is not contiguous!
GGML_ASSERT(src0->ne[1] == src1->ne[1]); GGML_ASSERT(src0->ne[1] == src1->ne[1]);
GGML_ASSERT(n_as*ggml_nrows(src0)); GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int)));
m_ids.resize(ggml_nbytes(src0)/sizeof(int)); m_ids.resize(ggml_nbytes(src0)/sizeof(int));
ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0)); ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0));


@ -8,6 +8,7 @@
#include <sstream> #include <sstream>
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include <atomic>
#include <vector> #include <vector>
#include <array> #include <array>
#include <fstream> #include <fstream>
@ -324,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll = 0.0; double nll = 0.0;
double nll2 = 0.0; double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@ -332,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int start = i * n_ctx; const int start = i * n_ctx;
const int end = start + n_ctx; const int end = start + n_ctx;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
@ -361,9 +365,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS // restore the original token in case it was set to BOS
tokens[batch_start] = token_org; tokens[batch_start] = token_org;
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx); const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
} }
}
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
@ -391,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// last 256 tokens. Then, we split the input up into context window size chunks to // last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt. // process the entire prompt.
const int first = n_ctx/2; const int first = n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1; count += n_ctx - first - 1;
@ -405,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
} }
fflush(stdout); fflush(stdout);
logits.clear();
} }
printf("\n"); printf("\n");
@ -422,26 +431,73 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
return {tokens, ppl, logit_history, prob_history}; return {tokens, ppl, logit_history, prob_history};
} }
static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens, static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
int n_past, int n_batch, int n_vocab) { for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
std::vector<float> result; const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
result.reserve(tokens.size() * n_vocab);
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; llama_batch batch_view = {
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { n_tokens,
size_t n_tokens = tokens.size() - i_chunk * n_batch; batch.token + i,
n_tokens = std::min(n_tokens, size_t(n_batch)); nullptr,
llama_kv_cache_seq_rm(ctx, 0, n_past, -1); batch.pos + i,
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) { batch.n_seq_id + i,
fprintf(stderr, "%s : failed to eval\n", __func__); batch.seq_id + i,
return {}; batch.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
} }
const auto logits = llama_get_logits(ctx); memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
n_past += n_tokens;
} }
return result;
return true;
}
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
constexpr int k_token_chunk = 4;
if (eval_results.size() != eval_pairs.size()) {
eval_results.resize(eval_pairs.size());
}
if (eval_pairs.empty()) return;
size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
std::atomic<int> counter(0);
auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
float local_logprobs[k_token_chunk];
while (true) {
size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
if (first >= eval_results.size()) break;
size_t last = std::min(first + k_token_chunk, eval_results.size());
for (size_t i = first; i < last; ++i) {
auto logits = batch_logits + eval_pairs[i].first * n_vocab;
float max_logit = logits[0];
for (int j = 1; j < n_vocab; ++j) {
max_logit = std::max(max_logit, logits[j]);
}
float sum_p = 0.f;
for (int j = 0; j < n_vocab; ++j) {
sum_p += expf(logits[j] - max_logit);
}
local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
}
std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
}
};
for (size_t it = 0; it < max_threads; ++it) {
workers[it] = std::thread(compute);
}
for (size_t it = 0; it < max_threads; ++it) {
workers[it].join();
}
} }
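
The quantity accumulated per (position, token) pair by `compute_logprobs` above is a numerically stable log-softmax of the target token's logit; with $z_j$ the logits at that position and $t$ the target token, each worker computes

$$\log p(t) = z_t - \max_j z_j - \log \sum_j \exp\bigl(z_j - \max_k z_k\bigr),$$

which is exactly the `logits[eval_pairs[i].second] - max_logit - std::log(sum_p)` expression in the lambda.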
static void hellaswag_score(llama_context * ctx, const gpt_params & params) { static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
@ -533,7 +589,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// determine the common prefix of the endings // determine the common prefix of the endings
hs_cur.common_prefix = 0; hs_cur.common_prefix = 0;
hs_cur.required_tokens = 0;
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) { for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] || if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] || hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@ -566,40 +621,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int max_tasks_per_batch = params.n_parallel; const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch; const int max_seq = 4*max_tasks_per_batch;
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_ctx*n_vocab); std::vector<float> batch_logits(n_vocab*n_ctx);
auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) { std::vector<std::pair<size_t, llama_token>> eval_pairs;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { std::vector<float> eval_results;
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); std::vector<std::thread> workers(std::thread::hardware_concurrency());
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}
return true;
};
for (size_t i0 = 0; i0 < hs_task_count; i0++) { for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0; int n_cur = 0;
@ -649,11 +681,29 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, n_batch)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__); fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return; return;
} }
// Compute log-probs in parallel
// First we collect all tasks
eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
size_t li = hs_cur.common_prefix;
for (int s = 0; s < 4; ++s) {
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]));
}
++li;
}
}
// Then we do the actual calculation
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
size_t ir = 0;
// compute the logprobs for each ending of the decoded tasks // compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) { for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i]; auto & hs_cur = hs_data[i];
@ -662,26 +712,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const auto first_probs = softmax(tok_logits); const auto first_probs = softmax(tok_logits);
size_t li = hs_cur.common_prefix; // logits index in the batch
for (int s = 0; s < 4; ++s) { for (int s = 0; s < 4; ++s) {
hs_cur.ending_logprob_count[s] = 1; hs_cur.ending_logprob_count[s] = 1;
hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]); hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
// Calculate the logprobs over the ending
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float)); hs_cur.ending_logprob[s] += eval_results[ir++];
const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
hs_cur.ending_logprob[s] += std::log(prob);
hs_cur.ending_logprob_count[s]++; hs_cur.ending_logprob_count[s]++;
} }
// account that we skip the last token in the ending
++li;
// Calculate the mean token logprob for acc_norm
hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s]; hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
} }
@ -720,6 +757,13 @@ struct winogrande_entry {
std::string second; std::string second;
std::array<std::string, 2> choices; std::array<std::string, 2> choices;
int answer; int answer;
size_t i_batch;
size_t common_prefix;
size_t required_tokens;
size_t n_base1; // number of tokens for context + choice 1
size_t n_base2; // number of tokens for context + choice 2
std::vector<llama_token> seq_tokens[2];
}; };
static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) { static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
@ -813,7 +857,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
} }
float scale = 1/(1.f + (float)rng.max()); float scale = 1/(1.f + (float)rng.max());
std::vector<winogrande_entry> selected; std::vector<winogrande_entry> selected;
selected.reserve(params.winogrande_tasks); selected.resize(params.winogrande_tasks);
for (int i = 0; i < int(params.winogrande_tasks); ++i) { for (int i = 0; i < int(params.winogrande_tasks); ++i) {
int j = int(scale*rng()*aux.size()); int j = int(scale*rng()*aux.size());
selected[i] = std::move(data[aux[j]]); selected[i] = std::move(data[aux[j]]);
@ -823,104 +867,145 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
data = std::move(selected); data = std::move(selected);
} }
fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
// This is needed as usual for LLaMA models // This is needed as usual for LLaMA models
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
for (auto & task : data) {
task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
task.common_prefix = 0;
for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
break;
}
task.common_prefix++;
}
task.required_tokens = task.common_prefix +
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;
task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
}
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
std::vector<std::thread> workers(std::thread::hardware_concurrency());
int n_correct = 0; int n_correct = 0;
int n_done = 0; int n_done = 0;
for (size_t task_idx = 0; task_idx < data.size(); task_idx++) { for (size_t i0 = 0; i0 < data.size(); i0++) {
const auto& task = data[task_idx]; int n_cur = 0;
auto base_context = ::llama_tokenize(ctx, task.first, add_bos); size_t i1 = i0;
auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos); size_t i_batch = 0;
auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
auto sentence_1st = task.first + task.choices[0] + task.second; llama_batch_clear(batch);
auto sentence_2nd = task.first + task.choices[1] + task.second;
auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) { while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size()); const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
}
batch.logits[batch.n_tokens - 1] = true;
for (int s = 0; s < 2; ++s) {
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
}
}
data[i1].i_batch = i_batch;
i_batch += data[i1].required_tokens;
n_cur += data[i1].required_tokens;
if (++i1 == data.size()) {
break;
}
}
if (i0 == i1) {
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
return; return;
} }
auto query_1st_size = query_1st.size();
auto query_2nd_size = query_2nd.size();
// Speedup small evaluations by evaluating atleast 32 tokens
// For Winogrande this seems to slow it down rather than speed it up.
//if (query_1st.size() < 32) query_1st.resize(32);
//if (query_2nd.size() < 32) query_2nd.resize(32);
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
llama_kv_cache_clear(ctx); // decode all tasks [i0, i1)
auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab); if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
if (logits_1st.empty() || logits_2nd.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return; return;
} }
bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx && eval_pairs.clear();
query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx; for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];
const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
size_t li = n_base1 - 1;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1]));
}
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1]));
}
}
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
size_t ir = 0;
for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];
const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
float score_1st = 0; float score_1st = 0;
bool is_nan_1st = false; const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const auto& base_1 = skip_choice ? base_ctx_1st : base_context; const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0; for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) { score_1st += eval_results[ir++];
std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_1st[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_1st.c_str(), base_context.size());
is_nan_1st = true;
break;
} }
score_1st += std::log(prob); score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
}
score_1st /= (query_1st_size - base_1.size() - last_1st);
float score_2nd = 0; float score_2nd = 0;
bool is_nan_2nd = false; const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const auto& base_2 = skip_choice ? base_ctx_2nd : base_context; const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0; for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) { score_2nd += eval_results[ir++];
std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_2nd[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_2nd.c_str(), base_context.size());
is_nan_2nd = true;
break;
}
score_2nd += std::log(prob);
}
score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
if (is_nan_1st || is_nan_2nd) {
continue;
}
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
continue;
} }
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
int result = score_1st > score_2nd ? 1 : 2; int result = score_1st > score_2nd ? 1 : 2;
@ -929,11 +1014,14 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
} }
++n_done; ++n_done;
// Print the accumulated accuracy mean x 100 // print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer); printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
fflush(stdout); fflush(stdout);
} }
i0 = i1 - 1;
}
printf("\n"); printf("\n");
if (n_done < 100) return; if (n_done < 100) return;
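
In both the HellaSwag and Winogrande paths above, each candidate ending is ranked by the mean per-token log-probability of its continuation (the tokens after the common prefix, with the `skip_choice` adjustment in the Winogrande case), and the candidate with the higher score wins; schematically:

$$\text{score}(c) = \frac{1}{|E_c|} \sum_{j \in E_c} \log p\bigl(t^{(c)}_{j+1} \mid t^{(c)}_{\le j}\bigr)$$

where $E_c$ is the set of scored positions for ending $c$.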


@ -1558,6 +1558,7 @@ struct llama_server_context
void process_tasks() void process_tasks()
{ {
std::unique_lock<std::mutex> lock(mutex_tasks); std::unique_lock<std::mutex> lock(mutex_tasks);
std::vector<task_server> deferred_tasks;
while (!queue_tasks.empty()) while (!queue_tasks.empty())
{ {
task_server task = queue_tasks.front(); task_server task = queue_tasks.front();
@ -1568,9 +1569,8 @@ struct llama_server_context
llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr) if (slot == nullptr)
{ {
LOG_TEE("slot unavailable\n"); // if no slot is available, we defer this task for processing later
// send error result deferred_tasks.push_back(task);
send_error(task, "slot unavailable");
break; break;
} }
@ -1616,6 +1616,12 @@ struct llama_server_context
} }
} }
// add all the deferred tasks back to the queue
for (task_server &task : deferred_tasks)
{
queue_tasks.push_back(task);
}
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
std::vector<task_result> agg_results; std::vector<task_result> agg_results;
auto queue_iterator = queue_multitasks.begin(); auto queue_iterator = queue_multitasks.begin();


@ -263,7 +263,6 @@ static void init_model(struct my_llama_model * model) {
model->data.resize(size + tensor_alignment); model->data.resize(size + tensor_alignment);
alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment); alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
alloc_model(alloc, model); alloc_model(alloc, model);
ggml_allocr_free(alloc);
} }
static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) { static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@ -1102,7 +1101,6 @@ int main(int argc, char ** argv) {
alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
ggml_allocr_alloc(alloc, tokens_input); ggml_allocr_alloc(alloc, tokens_input);
ggml_allocr_alloc(alloc, target_probs); ggml_allocr_alloc(alloc, target_probs);
ggml_allocr_free(alloc);
// context for compute tensors without their data // context for compute tensors without their data
const size_t estimated_compute_size_wo_data = ( const size_t estimated_compute_size_wo_data = (
@ -1149,7 +1147,6 @@ int main(int argc, char ** argv) {
best_compute_size = max_compute_size; best_compute_size = max_compute_size;
best_order = gf->order; best_order = gf->order;
} }
ggml_allocr_free(alloc);
ggml_free(ctx_compute); ggml_free(ctx_compute);
} }
size_t max_compute_size = best_compute_size; size_t max_compute_size = best_compute_size;
@ -1177,7 +1174,6 @@ int main(int argc, char ** argv) {
params.common.use_flash, params.common.use_flash,
params.common.use_checkpointing params.common.use_checkpointing
); );
ggml_allocr_free(alloc);
std::vector<llama_token> train_tokens; std::vector<llama_token> train_tokens;
std::vector<size_t> train_samples_begin; std::vector<size_t> train_samples_begin;


@ -12,9 +12,6 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <array> #include <array>
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#if defined(GGML_USE_HIPBLAS) #if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
@ -118,6 +115,11 @@
#endif // defined(GGML_USE_HIPBLAS) #endif // defined(GGML_USE_HIPBLAS)
// ggml-cuda needs the half type, so keep the ggml header includes last
#include "ggml-cuda.h"
#include "ggml.h"
#include "ggml-backend-impl.h"
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CC_PASCAL 600 #define CC_PASCAL 600


@ -147,7 +147,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@ -278,6 +280,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
NSURL * libURL = [NSURL fileURLWithPath:libPath]; NSURL * libURL = [NSURL fileURLWithPath:libPath];
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
} else { } else {
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
@ -316,14 +322,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
//[options setFastMathEnabled:false]; //[options setFastMathEnabled:false];
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
}
}
if (error) { if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL; return NULL;
} }
} }
}
}
// print MTL GPU family: // print MTL GPU family:
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
@ -396,6 +401,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
if (error) { \ if (error) { \
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \ return NULL; \
@ -512,7 +520,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@ -2166,21 +2176,50 @@ static bool ggml_metal_graph_compute(
} break; } break;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
{ {
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * src3 = gf->nodes[i]->src[3]; struct ggml_tensor * src3 = gf->nodes[i]->src[3];
GGML_ASSERT(ggml_are_same_shape(src1, src2));
size_t offs_src2 = 0; size_t offs_src2 = 0;
size_t offs_src3 = 0; size_t offs_src3 = 0;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; GGML_ASSERT(src2);
id<MTLBuffer> id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2);
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil;
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale; float scale;
memcpy(&scale, dst->op_params, sizeof(float)); memcpy(&scale, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; id<MTLComputePipelineState> pipeline = nil;
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ASSERT(false && "add template specialization for this size");
}
}
// TODO: extend if necessary // TODO: extend if necessary
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@ -2197,19 +2236,31 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&scale length:sizeof( float) atIndex:21]; [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int nth = MIN(1024, ne0); const int64_t nwarps = 32;
const int64_t nhptg = 2; // heads per threadgroup
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
} break; } break;
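
A note on the threadgroup memory requested above: under the kernel's shared-memory layout (one copy of the `nhptg` query heads plus, per simdgroup, an output accumulator of the same size and 32 halfs of scratch), the size works out to

$$\text{smem} = \bigl(R\,D + n_{\text{warps}} \cdot (R\,D + 32)\bigr) \cdot \text{sizeof(half)}$$

with $D$ = `ne00`, $R$ = `nhptg` = 2 and $n_{\text{warps}}$ = 32. For $D = 128$ this is $(256 + 32 \cdot 288) \cdot 2 = 18944$ bytes, comfortably below the 32 KiB threadgroup memory limit typical of Apple GPUs, hence the `maxThreadgroupMemoryLength` assert.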
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:


@ -1959,11 +1959,11 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
} }
kernel void kernel_flash_attn_ext_f16( typedef void (flash_attn_ext_f16_t)(
device const half * q, device const char * q,
device const half * k, device const char * k,
device const half * v, device const char * v,
device const float * mask, device const char * mask,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
@ -1973,21 +1973,237 @@ kernel void kernel_flash_attn_ext_f16(
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant uint64_t & nb03, constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant int64_t & ne2, constant int64_t & ne2,
constant int64_t & ne3, constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & scale, constant float & scale,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { uint3 ntg[[threads_per_threadgroup]],
// TODO: implement uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]);
template<int64_t D, int64_t R> // head size, rows per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups
const uint tph = N_SIMDWIDTH/R; // threads per head
const int64_t iq3 = tgpig[2];
const int64_t iq2 = tgpig[1]*R + tiisg/tph;
const int64_t iq1 = tgpig[0];
if (iq2 >= ne02) {
return;
}
// assume K and V are same shape
const int64_t ne22 = ne12;
const int64_t ne23 = ne13;
const uint64_t nb21 = nb11;
const uint64_t nb22 = nb12;
const uint64_t nb23 = nb13;
// broadcast
const int64_t rk2 = ne02/ne12;
const int64_t rk3 = ne03/ne13;
const int64_t rv2 = ne02/ne22;
const int64_t rv3 = ne03/ne23;
// k indices
const int64_t ik2 = iq2 / rk2;
const int64_t ik3 = iq3 / rk3;
// v indices
const int64_t iv2 = iq2 / rv2;
const int64_t iv3 = iq3 / rv3;
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
const int64_t D4 = D/4;
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D);
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D);
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D);
const uint tiih = tiisg%tph; // thread index in head
const uint hiisg = tiisg/tph; // head index in simdgroup
// load R heads from Q to shared memory
for (int64_t i = 0; i < D4/tph; ++i) {
if (sgitg == 0) {
pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
}
ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
half S = 0.0h;
half M = -INFINITY;
for (int64_t ic = sgitg; ic < ne11; ic += nsg) {
const half mv = mp ? mp[ic] : 0.0h;
if (mv == -INFINITY) {
continue;
}
device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
half4 s4 = 0.0h;
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
}
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
simdgroup_barrier(mem_flags::mem_threadgroup);
if (tiih == 0) {
half s = 0.0h;
#pragma unroll
for (int64_t i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i];
}
s = s*scale + mv;
const half m = M;
M = max(M, s);
const half ms = exp(m - M);
const half vs = exp(s - M);
S = S*ms + vs;
ss[2*hiisg + 0] = ms;
ss[2*hiisg + 1] = vs;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
const half ms = ss[2*hiisg + 0];
const half vs = ss[2*hiisg + 1];
#pragma unroll
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
}
}
if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// reduce the warps
if (sgitg == 0) {
for (int64_t sg = 1; sg < nsg; ++sg) {
const half S0 = ss[ 2*hiisg + 0];
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
M = max(M0, M1);
const half ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M);
S = S0*ms0 + S1*ms1;
if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
}
}
for (int64_t i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
}
}
simdgroup_barrier(mem_flags::mem_threadgroup);
// dst indices
const int64_t i1 = iq1;
const int64_t i2 = iq2;
const int64_t i3 = iq3;
device float4 * dst4 = (device float4 *) dst;
if (sgitg == 0) {
for (int64_t i = 0; i < D4/tph; ++i) {
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih];
}
}
} }
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,
device half * dst, device half * dst,
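
Both the Metal kernel above and the reworked CPU path in ggml.c below implement the online (streaming) softmax from the referenced paper (https://arxiv.org/pdf/2112.05682.pdf): for each new score $s$ with value row $v$, the running maximum $M$, normalizer $S$ and output accumulator $V$ are updated in a single pass as

$$M' = \max(M, s), \qquad S' = S\,e^{M - M'} + e^{s - M'}, \qquad V' = V\,e^{M - M'} + e^{s - M'}\,v,$$

and the final output is $V / S$, so the full softmax over the KV range never has to be materialized.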

ggml.c

@ -817,7 +817,7 @@ do { \
#if defined(__F16C__) #if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C // the _mm256_cvt intrinsics require F16C
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else #else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif #endif
} }
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v);
}
#endif
}
// xs and vs are byte strides of x and v // xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif #endif
} }
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
}
#endif
}
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext(
is_node = true; is_node = true;
} }
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q); // permute(0, 2, 1, 3)
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
float params[] = { scale }; float params[] = { scale };
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t D = neq0; const int64_t D = neq0;
const int64_t N = neq1; const int64_t N = neq1;
const int64_t P = nek1 - N; const int64_t P = nek1 - N;
const int64_t M = P + N;
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne1 == N); GGML_ASSERT(ne2 == N);
GGML_ASSERT(P >= 0); GGML_ASSERT(P >= 0);
GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(neq0 == D); GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D); GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev1 == D); GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N); GGML_ASSERT(neq1 == N);
GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nek1 == N + P);
GGML_ASSERT(nev1 == D); GGML_ASSERT(nev0 == D);
// dst cannot be transposed or permuted // dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16(
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// q indices // q indices
const int iq3 = ir/(neq2*neq1); const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); float S = 0.0f;
float M = -INFINITY;
for (int i = M; i < Mup; ++i) { float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
S[i] = -INFINITY; ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
}
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { memset(V16, 0, D*sizeof(ggml_fp16_t));
const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) { for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices const float mv = mp ? mp[ic] : 0.0f;
const int ik3 = iq3 / rk3; if (mv == -INFINITY) {
const int ik2 = iq2 / rk2; continue;
const int ik1 = ic; }
// S indices float s;
const int i1 = ik1;
ggml_vec_dot_f16(neq0, ggml_vec_dot_f16(D,
S + i1, &s,
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
s = s*scale + mv;
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, V16, ms);
} else { } else {
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { vs = expf(s - M);
// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_vec_dot_f16_unroll(neq0, nbk1,
S + i1,
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
} }
// scale const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_scale_f32(nek1, S, scale);
if (mask) { // V += v*expf(s - M)
const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); ggml_vec_mad_f16(D, V16, v16, vs);
ggml_vec_acc_f32(M, S, mp);
S = S*ms + vs;
} }
// softmax // V /= S
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. for (int64_t d = 0; d < D; ++d) {
// dont forget to set their S values to zero V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
ggml_float sum = 0.0;
{
#ifdef GGML_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(S, 1, &max, S, 1, Mup);
vvexpf(S, S, &Mup);
ggml_vec_sum_f32(Mup, &sum, S);
#else
uint16_t scvt[GGML_SOFT_MAX_UNROLL];
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
float * SS = S + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
sump[j] += (ggml_float)val;
SS[j] = val;
}
}
} }
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#endif
}
assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(M, S, sum);
#ifndef NDEBUG
for (int i = 0; i < M; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
#endif
}
ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
for (int64_t i = 0; i < M; i++) {
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices // dst indices
const int i1 = iq1; const int i1 = iq1;
const int i2 = iq2; const int i2 = iq2;
const int i3 = iq3; const int i3 = iq3;
// v indices // original
const int iv2 = iq2 / rv2; //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
const int iv3 = iq3 / rv3;
ggml_vec_dot_f16(nev0, // permute(0, 2, 1, 3)
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
} else {
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// v indices
const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3;
ggml_vec_dot_f16_unroll(nev0, nbv1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
}
} }
} }
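The rewritten CPU kernel drops the two-pass score-buffer-then-softmax scheme in favor of the single-pass online softmax from the cited paper (arxiv 2112.05682): it keeps a running maximum M and normalizer S per query row, rescales the accumulated V whenever M grows, and normalizes once at the end. A self-contained sketch of that recurrence on plain floats (illustrative names, no ggml types; assumes at least one unmasked position):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One query row attending over scores.size() key positions: scores[i] is the
    // already scaled, mask-added dot product q.k_i, and V[i] is the i-th value row
    // of size D. Accumulates softmax(scores) * V in a single pass.
    static std::vector<float> online_softmax_attend(
            const std::vector<float> & scores,
            const std::vector<std::vector<float>> & V,
            int D) {
        std::vector<float> acc(D, 0.0f); // running sum of exp(s_i - M) * V_i
        float S = 0.0f;                  // running sum of exp(s_i - M)
        float M = -INFINITY;             // running maximum score

        for (std::size_t i = 0; i < scores.size(); ++i) {
            const float s = scores[i];
            if (s == -INFINITY) {
                continue;                // masked-out position
            }

            float ms = 1.0f;             // rescale factor for the previous state
            float vs = 1.0f;             // weight of the current value row
            if (s > M) {
                ms = std::exp(M - s);    // shrink old contributions by exp(M_old - M_new)
                M  = s;
            } else {
                vs = std::exp(s - M);
            }

            for (int d = 0; d < D; ++d) {
                acc[d] = acc[d]*ms + V[i][d]*vs;
            }
            S = S*ms + vs;
        }

        for (int d = 0; d < D; ++d) {
            acc[d] /= S;                 // final normalization
        }
        return acc;
    }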
@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break; } break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
case GGML_OP_FLASH_ATTN_EXT:
{ {
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
} }
} break; } break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
} break;
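The new GGML_OP_FLASH_ATTN_EXT case only reserves two head-sized accumulators per thread (an fp32 one and an fp16 one) instead of the n_kv-sized score buffer of the old op, hence 2*sizeof(float)*ne00*n_tasks. A sketch of the per-thread slot layout the kernel assumes (illustrative; the real wdata indexing also adds CACHE_LINE_SIZE_F32 padding per thread, omitted here):

    #include <cstdint>
    #include <cstring>

    // Per-thread scratch assumed by the new kernel: D floats for the fp32 result
    // (V32), followed by D fp16 values (stored here as uint16_t) for the running
    // accumulator (V16), all carved out of one shared work buffer.
    static void carve_thread_scratch(float * wdata, int ith, int64_t D,
                                     float ** V32, uint16_t ** V16) {
        float * base = wdata + (size_t) ith * (size_t) (2*D); // 2*sizeof(float)*D bytes per thread
        *V32 = base;                        // fp32 buffer, D elements
        *V16 = (uint16_t *) (base + D);     // fp16 buffer, D elements (2*D bytes)
        memset(*V16, 0, (size_t) D * sizeof(uint16_t));
    }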
case GGML_OP_FLASH_FF: case GGML_OP_FLASH_FF:
{ {
if (node->src[1]->type == GGML_TYPE_F32) { if (node->src[1]->type == GGML_TYPE_F32) {

ggml.h

@ -1620,6 +1620,11 @@ extern "C" {
struct ggml_tensor * v, struct ggml_tensor * v,
bool masked); bool masked);
// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch, 1, 1]
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext( GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * q, struct ggml_tensor * q,
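Combined with the test case at the bottom of this diff, the new header comment pins down how callers are expected to shape the operands: head size innermost everywhere, v not transposed, and a 2-D [n_kv, n_batch] mask. A hedged usage sketch (tensor sizes are illustrative; a real caller would also allocate and compute the graph):

    #include <math.h>
    #include "ggml.h"

    // hs = head size, nb = batch (token count), nh = heads, kv = KV length.
    // The sizes are illustrative; they mirror the defaults of test_flash_attn_ext.
    static struct ggml_tensor * build_flash_attn_example(struct ggml_context * ctx) {
        const int hs = 128, nh = 32, kv = 96, nb = 8;

        struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, nb, nh, 1);
        struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
        struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); // !! not transposed !!
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1,  1);

        // the result comes back permuted: [hs, nh, nb, 1]
        return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float)hs));
    }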


@ -97,8 +97,10 @@ class MODEL_ARCH(IntEnum):
BLOOM = auto() BLOOM = auto()
STABLELM = auto() STABLELM = auto()
QWEN = auto() QWEN = auto()
QWEN2 = auto()
PHI2 = auto() PHI2 = auto()
PLAMO = auto() PLAMO = auto()
CODESHELL = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
@ -145,8 +147,10 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
} }
TENSOR_NAMES: dict[MODEL_TENSOR, str] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -356,6 +360,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.QWEN2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PLAMO: [ MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
@ -396,6 +414,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.CODESHELL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
] ]
# TODO # TODO
} }
@ -417,6 +448,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.ATTN_ROT_EMBD,
], ],
MODEL_ARCH.CODESHELL: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
} }
# #


@ -154,6 +154,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
"transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
), ),
# Feed-forward norm # Feed-forward norm

llama.cpp

@ -95,6 +95,8 @@
#define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_EXPERTS 8 #define LLAMA_MAX_EXPERTS 8
#define LLAMA_FLASH_ATTN
// //
// logging // logging
// //
@ -192,8 +194,10 @@ enum llm_arch {
LLM_ARCH_BLOOM, LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM, LLM_ARCH_STABLELM,
LLM_ARCH_QWEN, LLM_ARCH_QWEN,
LLM_ARCH_QWEN2,
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
LLM_ARCH_PLAMO, LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -211,8 +215,10 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_CODESHELL, "codeshell" },
}; };
enum llm_kv { enum llm_kv {
@ -566,6 +572,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_QWEN2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
{ {
@ -600,6 +623,26 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_CODESHELL,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
@ -1599,7 +1642,7 @@ struct llama_model {
std::unique_ptr<llama_mmap> mapping; std::unique_ptr<llama_mmap> mapping;
// objects representing data potentially being locked in memory // objects representing data potentially being locked in memory
llama_mlock mlock_buf; std::vector<std::unique_ptr<llama_mlock>> mlock_bufs;
llama_mlock mlock_mmap; llama_mlock mlock_mmap;
// for quantize-stats only // for quantize-stats only
@ -2847,6 +2890,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_QWEN2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
case 80: model.type = e_model::MODEL_70B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_PHI2: case LLM_ARCH_PHI2:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -2877,6 +2931,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_CODESHELL:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 42: model.type = e_model::MODEL_SMALL; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0; default: (void)0;
} }
@ -3438,7 +3500,12 @@ static bool llm_load_tensors(
{ {
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_OUTPUT, "weight").c_str()) >= 0) {
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
} else {
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU
ml.n_created--; // artificial tensor
}
} }
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
@ -3669,6 +3736,41 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2});
} }
} break; } break;
case LLM_ARCH_QWEN2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
// optional bias tensors
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
case LLM_ARCH_PHI2: case LLM_ARCH_PHI2:
{ {
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -3754,6 +3856,42 @@ static bool llm_load_tensors(
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
} }
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
}
} break;
case LLM_ARCH_CODESHELL:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i); ggml_context * ctx_split = ctx_for_layer_split(i);
@ -3815,8 +3953,10 @@ static bool llm_load_tensors(
else { else {
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf != nullptr && use_mlock && ggml_backend_buffer_is_host(buf)) { if (buf != nullptr && use_mlock && ggml_backend_buffer_is_host(buf)) {
model.mlock_buf.init (ggml_backend_buffer_get_base(buf)); model.mlock_bufs.emplace_back(new llama_mlock);
model.mlock_buf.grow_to(ggml_backend_buffer_get_size(buf)); auto & mlock_buf = model.mlock_bufs.back();
mlock_buf->init (ggml_backend_buffer_get_base(buf));
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
} }
} }
if (buf == nullptr) { if (buf == nullptr) {
@ -4029,23 +4169,34 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il); cb(k_cache_view, "k_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
#if defined(LLAMA_FLASH_ATTN)
// NOTE: the V cache is not transposed when using FLASH attention !!
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
(ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
cb(v_cache_view, "v_cache_view", il);
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
GGML_UNUSED(n_ctx);
#else
// compute the transposed [n_tokens, n_embd] V matrix
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]), ( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il])); (kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
#endif
} }
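With LLAMA_FLASH_ATTN defined, the V cache keeps the same row-major [n_embd_v_gqa, n_kv] layout as the K cache instead of the transposed layout used by the regular path, so appending n_tokens rows at kv_head becomes one contiguous copy. A rough sketch of the two copy patterns on plain arrays (illustrative names, not the ggml view machinery):

    #include <cstdint>
    #include <cstring>

    // Flash-attention layout: v_cache is [n_ctx][n_embd] row-major, so the rows for
    // n_tokens new tokens land contiguously starting at row kv_head.
    static void store_v_flash(float * v_cache, const float * v_cur,
                              int64_t n_embd, int64_t kv_head, int64_t n_tokens) {
        memcpy(v_cache + kv_head*n_embd, v_cur, (size_t) (n_tokens*n_embd) * sizeof(float));
    }

    // Regular layout: v_cache is transposed, [n_embd][n_ctx]; every embedding
    // component is a row of length n_ctx, so the store is strided.
    static void store_v_transposed(float * v_cache, const float * v_cur,
                                   int64_t n_embd, int64_t n_ctx,
                                   int64_t kv_head, int64_t n_tokens) {
        for (int64_t e = 0; e < n_embd; ++e) {
            for (int64_t t = 0; t < n_tokens; ++t) {
                v_cache[e*n_ctx + kv_head + t] = v_cur[t*n_embd + e];
            }
        }
    }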
static struct ggml_tensor * llm_build_norm( static struct ggml_tensor * llm_build_norm(
@ -4205,27 +4356,27 @@ static struct ggml_tensor * llm_build_kqv(
0); 0);
cb(k, "k", il); cb(k, "k", il);
// split cached v into n_head heads struct ggml_tensor * cur;
#if defined(LLAMA_FLASH_ATTN)
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v = struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il], ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv, n_embd_head_v, n_kv, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx, ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
0); 0);
cb(v, "v", il); cb(v, "v", il);
// TODO: determine if we can use flash attention cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
const bool supports_flash_attn = true;
struct ggml_tensor * kqv;
if (supports_flash_attn) {
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
//printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
} else {
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
#else
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
@ -4258,15 +4409,24 @@ static struct ggml_tensor * llm_build_kqv(
cb(kq, "kq_soft_max_ext", il); cb(kq, "kq_soft_max_ext", il);
} }
kqv = ggml_mul_mat(ctx, v, kq); // split cached v into n_head heads (transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il); cb(kqv, "kqv", il);
}
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il); cb(kqv_merged, "kqv_merged", il);
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il); cb(cur, "kqv_merged_cont", il);
#endif
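Under LLAMA_FLASH_ATTN the mul_mat / soft_max / mul_mat chain collapses into a single ggml_flash_attn_ext node, and because its result is already permuted to [n_embd_head, n_head, n_tokens] a plain 2-D reshape replaces the permute + cont of the regular path. The index arithmetic behind that shortcut, as an illustrative helper (not ggml code):

    #include <cstdint>

    // In the [hs, n_head, n_tokens] flash-attention result, all hs*n_head values of
    // a token are contiguous and ordered head-major, which is exactly the column
    // layout the following ggml_mul_mat with wo expects, so a reshape suffices.
    // The regular path produces [hs, n_tokens, n_head] and still needs permute + cont.
    static int64_t flash_out_index(int64_t d, int64_t h, int64_t t,
                                   int64_t hs, int64_t n_head) {
        return d + h*hs + t*hs*n_head; // row (h*hs + d) of column t after the 2-D reshape
    }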
cur = ggml_mul_mat(ctx, wo, cur); cur = ggml_mul_mat(ctx, wo, cur);
if (wo_b) { if (wo_b) {
@ -5638,6 +5798,128 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_qwen2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_phi2() { struct ggml_cgraph * build_phi2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -5971,6 +6253,117 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_codeshell() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(tmpq, "tmpq", il);
cb(tmpk, "tmpk", il);
cb(Vcur, "Vcur", il);
struct ggml_tensor * Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
// FF
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm,
model.layers[il].ffn_norm_b,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il);
}
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
}
cur = llm_build_norm(ctx0, inpL, hparams,
model.output_norm,
model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
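build_codeshell slices Q, K and V out of the fused wqkv projection with 2-D views at row offsets 0, n_embd and n_embd + n_embd_gqa. The same offset arithmetic on a plain float buffer, as an illustrative sketch (not the ggml view API):

    #include <cstdint>

    // One row of the fused QKV projection holds n_embd Q values, then n_embd_gqa K
    // values, then n_embd_gqa V values. Illustrative element offsets for token t in
    // a row-major [n_embd + 2*n_embd_gqa, n_tokens] buffer.
    struct qkv_row { const float * q; const float * k; const float * v; };

    static qkv_row split_qkv_row(const float * qkv, int64_t t,
                                 int64_t n_embd, int64_t n_embd_gqa) {
        const float * base = qkv + t*(n_embd + 2*n_embd_gqa);
        return { base, base + n_embd, base + n_embd + n_embd_gqa };
    }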
}; };
static struct ggml_cgraph * llama_build_graph( static struct ggml_cgraph * llama_build_graph(
@ -6153,6 +6546,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_qwen(); result = llm.build_qwen();
} break; } break;
case LLM_ARCH_QWEN2:
{
result = llm.build_qwen2();
} break;
case LLM_ARCH_PHI2: case LLM_ARCH_PHI2:
{ {
result = llm.build_phi2(); result = llm.build_phi2();
@ -6165,6 +6562,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_gpt2(); result = llm.build_gpt2();
} break; } break;
case LLM_ARCH_CODESHELL:
{
result = llm.build_codeshell();
} break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
} }


@ -4,7 +4,7 @@ wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag
echo "Usage:" echo "Usage:"
echo "" echo ""
echo " ./perplexity --hellaswag --hellaswag-tasks N -f hellaswag_val_full.txt -m modelfile.gguf" echo " ./perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag [--hellaswag-tasks N] [other params]"
echo "" echo ""
exit 0 exit 0


@ -1,3 +1,10 @@
#!/bin/bash #!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
echo "Usage:"
echo ""
echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]"
echo ""
exit 0

scripts/get-winogrande.sh (new executable file)

@ -0,0 +1,10 @@
#!/bin/bash
wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv
echo "Usage:"
echo ""
echo " ./perplexity -m model.gguf -f winogrande-debiased-eval.csv --winogrande [--winogrande-tasks N] [other params]"
echo ""
exit 0


@ -1390,21 +1390,25 @@ struct test_flash_attn_ext : public test_case {
const int64_t hs; // head size const int64_t hs; // head size
const int64_t nh; // num heads const int64_t nh; // num heads
const int64_t kv; // kv size const int64_t kv; // kv size
const int64_t nt; // tokens const int64_t nb; // batch size
std::string vars() override { std::string vars() override {
return VARS_TO_STR5(typeq, hs, nh, kv, nt); return VARS_TO_STR5(typeq, hs, nh, kv, nb);
}
double max_nmse_err() override {
return 5e-4;
} }
test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
: typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
return out; return out;
} }
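The test now calls the last q/mask dimension nb (batch), builds v as [hs, kv, nh] to match the non-transposed layout of the new op, and overrides max_nmse_err to 5e-4, presumably to absorb the fp16 accumulation of the rewritten kernel. For reference, a minimal sketch of the normalized mean squared error being bounded here (illustrative helper, not the test harness implementation):

    #include <cstddef>

    // Normalized mean squared error: error energy divided by reference energy.
    static double nmse(const float * out, const float * ref, size_t n) {
        double err = 0.0, sig = 0.0;
        for (size_t i = 0; i < n; ++i) {
            const double d = (double) out[i] - (double) ref[i];
            err += d*d;
            sig += (double) ref[i] * (double) ref[i];
        }
        return sig > 0.0 ? err/sig : 0.0;
    }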