skip-unused: disable skipping on ROCm / when LLAMA_USE_HIPBLAS

This commit is contained in:
ochafik 2023-08-23 21:36:56 +01:00 committed by Olivier Chafik
parent 2cf4f62e12
commit 7ec7ef94a9

View file

@ -56,6 +56,12 @@
#include <stdio.h> // for _fseeki64 #include <stdio.h> // for _fseeki64
#endif #endif
// TODO: Fix unused logit skipping crashes on ROCm
// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
#ifndef LLAMA_USE_HIPBLAS
#define LLAMA_SKIP_UNUSED_LOGITS
#endif
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cassert> #include <cassert>
@ -2277,6 +2283,7 @@ static struct ggml_cgraph * llm_build_llama(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
} }
#ifdef LLAMA_SKIP_UNUSED_LOGITS
if (il == n_layer - 1 && !lctx.logits_all) if (il == n_layer - 1 && !lctx.logits_all)
{ {
// From here on, we only care about the last token and its logits. // From here on, we only care about the last token and its logits.
@ -2297,6 +2304,7 @@ static struct ggml_cgraph * llm_build_llama(
n_past += N - 1; n_past += N - 1;
N = 1; N = 1;
} }
#endif // LLAMA_SKIP_UNUSED_LOGITS
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq); offload_func_kq(tmpq);
@ -2928,9 +2936,14 @@ static bool llama_eval_internal(
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
} else { } else {
// return result for just the last token // return result for just the last token
GGML_ASSERT(ggml_nelements(res) == n_vocab);
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
#ifdef LLAMA_SKIP_UNUSED_LOGITS
GGML_ASSERT(ggml_nelements(res) == n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
#else
GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
#endif
} }
} }