skip-unused: disable skipping on ROCm / when LLAMA_USE_HIPBLAS

commit 7ec7ef94a9
parent 2cf4f62e12

1 changed file with 14 additions and 1 deletion

llama.cpp (+14, -1)
@@ -56,6 +56,12 @@
 #include <stdio.h> // for _fseeki64
 #endif
 
+// TODO: Fix unused logit skipping crashes on ROCm
+// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
+#ifndef LLAMA_USE_HIPBLAS
+#define LLAMA_SKIP_UNUSED_LOGITS
+#endif
+
 #include <algorithm>
 #include <array>
 #include <cassert>
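
For reference, here is a minimal standalone sketch (not part of the patch) of how the new guard resolves: LLAMA_SKIP_UNUSED_LOGITS is only defined when LLAMA_USE_HIPBLAS is absent, so ROCm/hipBLAS builds compile with skipping disabled. The program and its output strings are illustrative only.

    // guard_sketch.cpp -- illustrative only; compile with and without
    // -DLLAMA_USE_HIPBLAS to emulate a hipBLAS (ROCm) vs. default build.
    #include <cstdio>

    #ifndef LLAMA_USE_HIPBLAS
    #define LLAMA_SKIP_UNUSED_LOGITS
    #endif

    int main() {
    #ifdef LLAMA_SKIP_UNUSED_LOGITS
        std::printf("unused-logit skipping: enabled (no hipBLAS)\n");
    #else
        std::printf("unused-logit skipping: disabled (LLAMA_USE_HIPBLAS defined)\n");
    #endif
        return 0;
    }
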
@@ -2277,6 +2283,7 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
         if (il == n_layer - 1 && !lctx.logits_all)
         {
             // From here on, we only care about the last token and its logits.
@@ -2297,6 +2304,7 @@ static struct ggml_cgraph * llm_build_llama(
             n_past += N - 1;
             N = 1;
         }
+#endif // LLAMA_SKIP_UNUSED_LOGITS
 
         struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
         offload_func_kq(tmpq);
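
The two hunks above bracket the existing fast path that drops unneeded rows from the graph: when this is the last layer and the caller did not request logits for every token, the batch is collapsed to a single position. A sketch of the bookkeeping, using example values (the variable names mirror the diff, the numbers are made up):

    // bookkeeping_sketch.cpp -- illustrative only, not llama.cpp code
    #include <cstdio>

    int main() {
        int  n_past     = 32;    // tokens already in the KV cache (example)
        int  N          = 8;     // tokens in the current batch (example)
        bool logits_all = false; // caller only wants the last token's logits
        bool last_layer = true;  // stands in for il == n_layer - 1

        if (last_layer && !logits_all) {
            // The K/V copies for all N tokens were already queued above,
            // so account for the first N-1 positions before shrinking.
            n_past += N - 1;     // 32 -> 39
            N = 1;               // only the last token flows to the logits
        }

        std::printf("n_past = %d, N = %d\n", n_past, N); // prints 39 and 1
        return 0;
    }

On ROCm this path is now compiled out entirely, so every token's logits are built even when only the last one is used.
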
@@ -2928,9 +2936,14 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
     } else {
         // return result for just the last token
-        GGML_ASSERT(ggml_nelements(res) == n_vocab);
         logits_out.resize(n_vocab);
+#ifdef LLAMA_SKIP_UNUSED_LOGITS
+        GGML_ASSERT(ggml_nelements(res) == n_vocab);
         memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
+#else
+        GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
+        memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+#endif
     }
 }
 
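The #else branch is what ROCm builds now take: the graph produced logits for all N tokens, so the last token's row has to be copied from offset n_vocab*(N-1) rather than from the start of the buffer. A minimal sketch of that offset arithmetic with toy sizes (not llama.cpp code):

    // readout_sketch.cpp -- illustrative only; toy sizes
    #include <cstdio>
    #include <cstring>
    #include <vector>

    int main() {
        const int n_vocab = 4;                       // toy vocabulary size
        const int N       = 3;                       // toy batch size

        std::vector<float> res(n_vocab * N);         // N rows of n_vocab logits
        for (size_t i = 0; i < res.size(); ++i) res[i] = (float) i;

        std::vector<float> logits_out(n_vocab);

        // Non-skipping (#else) path: the last row starts at n_vocab*(N-1).
        std::memcpy(logits_out.data(), res.data() + n_vocab * (N - 1),
                    sizeof(float) * n_vocab);

        std::printf("first logit of last row = %.0f\n", logits_out[0]); // 8
        return 0;
    }

With LLAMA_SKIP_UNUSED_LOGITS defined, the buffer already holds only that last row, so the copy starts at offset 0 and the assert checks n_vocab elements instead of n_vocab * N.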