From 7ec7ef94a92161d8cf65866e0666001ba9fc8ac9 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Wed, 23 Aug 2023 21:36:56 +0100
Subject: [PATCH] skip-unused: disable skipping on ROCm / when LLAMA_USE_HIPBLAS

---
 llama.cpp | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 838f47e36..5692ced56 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -56,6 +56,12 @@
     #include <stdio.h> // for _fseeki64
 #endif
 
+// TODO: Fix unused logit skipping crashes on ROCm
+// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
+#ifndef LLAMA_USE_HIPBLAS
+#define LLAMA_SKIP_UNUSED_LOGITS
+#endif
+
 #include <algorithm>
 #include <array>
 #include <cassert>
@@ -2277,6 +2283,7 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
         if (il == n_layer - 1 && !lctx.logits_all) {
             // From here on, we only care about the last token and its logits.
@@ -2297,6 +2304,7 @@
             n_past += N - 1;
             N = 1;
         }
+#endif // LLAMA_SKIP_UNUSED_LOGITS
 
         struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
         offload_func_kq(tmpq);
@@ -2928,9 +2936,14 @@ static bool llama_eval_internal(
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
-            GGML_ASSERT(ggml_nelements(res) == n_vocab);
             logits_out.resize(n_vocab);
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
+            GGML_ASSERT(ggml_nelements(res) == n_vocab);
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
+#else
+            GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+#endif
         }
     }
 
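Note on the #else path: without LLAMA_SKIP_UNUSED_LOGITS the graph still produces
logits for all N tokens of the batch, so the assert widens to n_vocab * N elements
and the copy has to start at the last row, n_vocab * (N - 1) floats into the buffer.
A minimal standalone sketch of that indexing, using hypothetical sizes that are not
part of the patch:

    #include <cstdio>
    #include <cstring>
    #include <vector>

    int main() {
        // Hypothetical sizes for illustration: a batch of N tokens,
        // each producing n_vocab logits.
        const int n_vocab = 4;
        const int N       = 3;

        // The graph output is a flat row-major buffer: N rows of n_vocab floats.
        std::vector<float> res(n_vocab * N);
        for (int i = 0; i < n_vocab * N; i++) {
            res[i] = (float) i;
        }

        // Equivalent of the #else branch above: copy only the last token's
        // row, which starts at offset n_vocab * (N - 1).
        std::vector<float> logits_out(n_vocab);
        std::memcpy(logits_out.data(), res.data() + n_vocab * (N - 1),
                    sizeof(float) * n_vocab);

        for (int i = 0; i < n_vocab; i++) {
            std::printf("%g ", logits_out[i]); // prints: 8 9 10 11
        }
        std::printf("\n");
        return 0;
    }

With LLAMA_SKIP_UNUSED_LOGITS defined, the graph is instead cut down to a single
row after the last layer, so the stricter n_vocab-element assert applies and the
copy starts at offset 0; ROCm / hipBLAS builds leave the macro undefined and fall
back to the full-batch behavior until the crash linked above is fixed.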