llama : per-layer KV cache + quantum K cache (#4309)

* per-layer KV

* remove unnecessary copies

* less code duplication, offload k and v separately

* llama : offload KV cache per-layer

* llama : offload K shift tensors

* llama : offload for rest of the model arches

* llama : enable offload debug temporarily

* llama : keep the KV related layers on the device

* llama : remove mirrors, perform Device -> Host when partial offload

* common : add command-line arg to disable KV cache offloading

* llama : update session save/load

* llama : support quantum K cache (#4312)

* llama : support quantum K cache (wip)

* metal : add F32 -> Q8_0 copy kernel

* cuda : add F32 -> Q8_0 copy kernel

ggml-ci

* cuda : use mmv kernel for quantum cache ops

* llama : pass KV cache type through API

* llama : fix build

ggml-ci

* metal : add F32 -> Q4_0 copy kernel

* metal : add F32 -> Q4_1 copy kernel

* cuda : wip

* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels

* llama-bench : support type_k/type_v

* metal : use mm kernel only for quantum KV cache

* cuda : add comment

* llama : remove memory_f16 and kv_f16 flags

---------

Co-authored-by: slaren <slarengh@gmail.com>

* readme : add API change notice

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2023-12-07 13:03:17 +02:00 committed by GitHub
parent 81bc9214a3
commit bcc0eb4591
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 747 additions and 287 deletions

440
llama.cpp
View file

@ -1231,6 +1231,7 @@ struct llama_cparams {
float yarn_beta_slow;
bool mul_mat_q;
bool offload_kqv;
};
struct llama_layer {
@ -1299,8 +1300,8 @@ struct llama_kv_cache {
std::vector<llama_kv_cell> cells;
struct ggml_tensor * k = NULL;
struct ggml_tensor * v = NULL;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
struct ggml_context * ctx = NULL;
@ -1313,8 +1314,10 @@ struct llama_kv_cache {
#ifdef GGML_USE_CUBLAS
if (ggml_cublas_loaded()) {
ggml_cuda_free_data(k);
ggml_cuda_free_data(v);
for (size_t i = 0; i < k_l.size(); ++i) {
ggml_cuda_free_data(k_l[i]);
ggml_cuda_free_data(v_l[i]);
}
}
#endif
}
@ -1504,9 +1507,11 @@ struct llama_context {
static bool llama_kv_cache_init(
const struct llama_hparams & hparams,
struct llama_kv_cache & cache,
ggml_type wtype,
ggml_type ktype,
ggml_type vtype,
uint32_t n_ctx,
int n_gpu_layers) {
int n_gpu_layers,
bool offload) {
const uint32_t n_embd = hparams.n_embd_gqa();
const uint32_t n_layer = hparams.n_layer;
@ -1522,7 +1527,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(n_ctx);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
memset(cache.buf.data, 0, cache.buf.size);
struct ggml_init_params params;
@ -1532,37 +1537,44 @@ static bool llama_kv_cache_init(
cache.ctx = ggml_init(params);
size_t vram_kv_cache = 0;
if (!cache.ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
return false;
}
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
ggml_set_name(cache.k, "cache_k");
ggml_set_name(cache.v, "cache_v");
cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);
(void) n_gpu_layers;
const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
GGML_UNUSED(offload);
for (int i = 0; i < (int) n_layer; i++) {
ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
#ifdef GGML_USE_CUBLAS
if (ggml_cublas_loaded()) {
size_t vram_kv_cache = 0;
if (n_gpu_layers > (int)n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
vram_kv_cache += ggml_nbytes(cache.v);
}
if (n_gpu_layers > (int)n_layer + 2) {
ggml_cuda_assign_buffers_no_scratch(cache.k);
LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
vram_kv_cache += ggml_nbytes(cache.k);
}
if (vram_kv_cache > 0) {
LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
if (i >= i_gpu_start) {
if (offload) {
ggml_cuda_assign_buffers_no_scratch(k);
vram_kv_cache += ggml_nbytes(k);
ggml_cuda_assign_buffers_no_scratch(v);
vram_kv_cache += ggml_nbytes(v);
}
}
#endif // GGML_USE_CUBLAS
}
#endif
if (vram_kv_cache > 0) {
LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
}
GGML_UNUSED(n_gpu_layers);
return true;
}
@ -2968,14 +2980,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3045,14 +3050,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3115,14 +3113,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3192,14 +3183,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3269,21 +3253,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > int(n_layer + 1)) {
LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
__func__, n_layer + 1);
throw std::runtime_error("Persimmon CUDA offload failed");
}
#endif
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3342,14 +3312,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3420,14 +3383,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3487,14 +3443,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3559,14 +3508,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3642,8 +3584,8 @@ static void llm_load_tensors(
}
#ifdef GGML_USE_CUBLAS
const int max_backend_supported_layers = hparams.n_layer + 3;
const int max_offloadable_layers = hparams.n_layer + 3;
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
#elif GGML_USE_CLBLAST
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
@ -3811,11 +3753,11 @@ static void llm_build_k_shift(
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
0),
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
@ -3842,13 +3784,13 @@ static void llm_build_kv_store(
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
(ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
(ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
cb(k_cache_view, "k_cache_view", il);
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv.v),
(il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
@ -4000,11 +3942,11 @@ static struct ggml_tensor * llm_build_kqv(
cb(q, "q", il);
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_kv, n_head_kv,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
0);
cb(k, "k", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@ -4035,11 +3977,11 @@ static struct ggml_tensor * llm_build_kqv(
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v,
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv.v)*n_ctx,
ggml_element_size(kv.v)*n_ctx*n_embd_head,
ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
@ -4631,6 +4573,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "imp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
@ -4638,6 +4581,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5237,15 +5181,15 @@ struct llm_build_context {
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it wil be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
@ -5351,8 +5295,8 @@ struct llm_build_context {
enum llm_offload_func_e {
OFFLOAD_FUNC_NOP,
OFFLOAD_FUNC,
OFFLOAD_FUNC_KQ,
OFFLOAD_FUNC_V,
OFFLOAD_FUNC_FRC, // force offload
OFFLOAD_FUNC_KQV,
OFFLOAD_FUNC_NR,
OFFLOAD_FUNC_EMB,
OFFLOAD_FUNC_OUT,
@ -5438,11 +5382,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
//{ "inp_embd", OFFLOAD_FUNC_NR }, // TODO: missing K-quants get_rows kernel
{ "pos_embd", OFFLOAD_FUNC_NR },
{ "inp_pos", OFFLOAD_FUNC_KQ }, // this is often used for KQ ops (e.g. rope)
{ "KQ_scale", OFFLOAD_FUNC_KQ },
{ "KQ_mask", OFFLOAD_FUNC_KQ },
{ "K_shift", OFFLOAD_FUNC_KQ },
{ "K_shifted", OFFLOAD_FUNC_KQ },
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
{ "KQ_scale", OFFLOAD_FUNC_FRC },
{ "KQ_mask", OFFLOAD_FUNC_FRC },
{ "K_shift", OFFLOAD_FUNC_FRC },
{ "K_shifted", OFFLOAD_FUNC },
{ "inp_norm", OFFLOAD_FUNC_NR },
{ "inp_norm_w", OFFLOAD_FUNC_NR },
@ -5455,38 +5400,38 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "attn_norm", OFFLOAD_FUNC },
{ "attn_norm_2", OFFLOAD_FUNC },
{ "wqkv", OFFLOAD_FUNC_KQ },
{ "bqkv", OFFLOAD_FUNC_KQ },
{ "wqkv_clamped", OFFLOAD_FUNC_KQ },
{ "wqkv", OFFLOAD_FUNC_KQV },
{ "bqkv", OFFLOAD_FUNC_KQV },
{ "wqkv_clamped", OFFLOAD_FUNC_KQV },
{ "tmpk", OFFLOAD_FUNC_KQ },
{ "tmpq", OFFLOAD_FUNC_KQ },
{ "tmpv", OFFLOAD_FUNC_V },
{ "Kcur", OFFLOAD_FUNC_KQ },
{ "Qcur", OFFLOAD_FUNC_KQ },
{ "Vcur", OFFLOAD_FUNC_V },
{ "tmpk", OFFLOAD_FUNC_KQV },
{ "tmpq", OFFLOAD_FUNC_KQV },
{ "tmpv", OFFLOAD_FUNC_KQV },
{ "Kcur", OFFLOAD_FUNC_KQV },
{ "Qcur", OFFLOAD_FUNC_KQV },
{ "Vcur", OFFLOAD_FUNC_KQV },
{ "krot", OFFLOAD_FUNC_KQ },
{ "qrot", OFFLOAD_FUNC_KQ },
{ "kpass", OFFLOAD_FUNC_KQ },
{ "qpass", OFFLOAD_FUNC_KQ },
{ "krotated", OFFLOAD_FUNC_KQ },
{ "qrotated", OFFLOAD_FUNC_KQ },
{ "krot", OFFLOAD_FUNC_KQV },
{ "qrot", OFFLOAD_FUNC_KQV },
{ "kpass", OFFLOAD_FUNC_KQV },
{ "qpass", OFFLOAD_FUNC_KQV },
{ "krotated", OFFLOAD_FUNC_KQV },
{ "qrotated", OFFLOAD_FUNC_KQV },
{ "q", OFFLOAD_FUNC_KQ },
{ "k", OFFLOAD_FUNC_KQ },
{ "kq", OFFLOAD_FUNC_KQ },
{ "kq_scaled", OFFLOAD_FUNC_KQ },
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
{ "kq_masked", OFFLOAD_FUNC_KQ },
{ "kq_soft_max", OFFLOAD_FUNC_V },
{ "kq_soft_max_ext", OFFLOAD_FUNC_V },
{ "v", OFFLOAD_FUNC_V },
{ "kqv", OFFLOAD_FUNC_V },
{ "kqv_merged", OFFLOAD_FUNC_V },
{ "kqv_merged_cont", OFFLOAD_FUNC_V },
{ "kqv_wo", OFFLOAD_FUNC_V },
{ "kqv_out", OFFLOAD_FUNC_V },
{ "q", OFFLOAD_FUNC_KQV },
{ "k", OFFLOAD_FUNC_KQV },
{ "kq", OFFLOAD_FUNC_KQV },
{ "kq_scaled", OFFLOAD_FUNC_KQV },
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQV },
{ "kq_masked", OFFLOAD_FUNC_KQV },
{ "kq_soft_max", OFFLOAD_FUNC_KQV },
{ "kq_soft_max_ext", OFFLOAD_FUNC_KQV },
{ "v", OFFLOAD_FUNC_KQV },
{ "kqv", OFFLOAD_FUNC_KQV },
{ "kqv_merged", OFFLOAD_FUNC_KQV },
{ "kqv_merged_cont", OFFLOAD_FUNC_KQV },
{ "kqv_wo", OFFLOAD_FUNC_KQV },
{ "kqv_out", OFFLOAD_FUNC_KQV },
{ "ffn_inp", OFFLOAD_FUNC },
{ "ffn_norm", OFFLOAD_FUNC },
@ -5678,15 +5623,15 @@ static struct ggml_cgraph * llama_build_graph(
{ OFFLOAD_FUNC_NOP, "CPU" },
{ OFFLOAD_FUNC_OUT, "CPU" },
#ifdef GGML_USE_CUBLAS
{ OFFLOAD_FUNC, "GPU (CUDA)" },
{ OFFLOAD_FUNC_KQ, "GPU (CUDA) KQ" },
{ OFFLOAD_FUNC_V, "GPU (CUDA) V" },
{ OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
{ OFFLOAD_FUNC, "GPU (CUDA)" },
{ OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" },
{ OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" },
{ OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
{ OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
#else
{ OFFLOAD_FUNC, "CPU" },
{ OFFLOAD_FUNC_KQ, "CPU" },
{ OFFLOAD_FUNC_V, "CPU" },
{ OFFLOAD_FUNC_FRC, "CPU" },
{ OFFLOAD_FUNC_KQV, "CPU" },
{ OFFLOAD_FUNC_NR, "CPU" },
{ OFFLOAD_FUNC_EMB, "CPU" },
#endif // GGML_USE_CUBLAS
@ -5719,21 +5664,26 @@ static struct ggml_cgraph * llama_build_graph(
}
}
break;
case OFFLOAD_FUNC_FRC:
if (!lctx.cparams.offload_kqv) {
func_e = OFFLOAD_FUNC_NOP;
} break;
case OFFLOAD_FUNC_KQV:
if (!lctx.cparams.offload_kqv) {
func_e = OFFLOAD_FUNC_NOP;
} else {
if (n_gpu_layers < n_layer) {
if (il < i_gpu_start) {
func_e = OFFLOAD_FUNC_NOP;
}
}
}
break;
case OFFLOAD_FUNC_NR:
if (n_gpu_layers <= n_layer + 0) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_V:
if (n_gpu_layers <= n_layer + 1) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_KQ:
if (n_gpu_layers <= n_layer + 2) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_EMB:
if (!offload_emb || n_gpu_layers < n_layer) {
func_e = OFFLOAD_FUNC_NOP;
@ -5755,8 +5705,8 @@ static struct ggml_cgraph * llama_build_graph(
case OFFLOAD_FUNC_NOP:
case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
case OFFLOAD_FUNC:
case OFFLOAD_FUNC_KQ:
case OFFLOAD_FUNC_V:
case OFFLOAD_FUNC_KQV:
case OFFLOAD_FUNC_FRC:
case OFFLOAD_FUNC_NR:
case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break;
default: GGML_ASSERT(false);
@ -5942,6 +5892,7 @@ static int llama_decode_internal(
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@ -5992,7 +5943,7 @@ static int llama_decode_internal(
n_threads = std::min(4, n_threads);
}
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
if (ggml_cpu_has_cublas() && fully_offloaded) {
n_threads = 1;
}
@ -8821,10 +8772,12 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.mul_mat_q =*/ true,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
/*.embedding =*/ false,
/*.offload_kqv =*/ true,
};
return result;
@ -8941,6 +8894,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@ -8974,19 +8928,36 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
const ggml_type type_k = params.type_k;
const ggml_type type_v = params.type_v;
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
// reserve memory for context buffers
if (!hparams.vocab_only) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}
{
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
size_t memory_size_k = 0;
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// resized during inference
@ -9057,8 +9028,12 @@ struct llama_context * llama_new_context_with_model(
}
size_t kv_vram_size = 0;
add_tensor(ctx->kv_self.k, kv_vram_size);
add_tensor(ctx->kv_self.v, kv_vram_size);
for (auto & k : ctx->kv_self.k_l) {
add_tensor(k, kv_vram_size);
}
for (auto & v : ctx->kv_self.v_l) {
add_tensor(v, kv_vram_size);
}
size_t ctx_vram_size = alloc_size + kv_vram_size;
size_t total_vram_size = model_vram_size + ctx_vram_size;
@ -9528,37 +9503,45 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
kout3d->data = kout3d_data.data();
std::vector<std::vector<uint8_t>> kout2d_data(n_layer);
std::vector<std::vector<uint8_t>> vout2d_data(n_layer);
ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
vout3d->data = vout3d_data.data();
for (int il = 0; il < (int) n_layer; ++il) {
ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
kout2d_data[il].resize(ggml_nbytes(kout2d));
kout2d->data = kout2d_data[il].data();
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_head, n_layer,
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
vout2d_data[il].resize(ggml_nbytes(vout2d));
vout2d->data = vout2d_data[il].data();
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
n_embd, kv_head,
elt_size*n_embd, 0);
ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
kv_head, n_embd,
elt_size*n_ctx, 0);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d));
}
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
// our data is now in the kout3d_data and vout3d_data buffers
// our data is now in the kout2d_data and vout2d_data buffers
// write them to file
data_ctx->write(kout3d_data.data(), kout3d_data.size());
data_ctx->write(vout3d_data.data(), vout3d_data.size());
for (uint32_t il = 0; il < n_layer; ++il) {
data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size());
data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size());
}
}
for (uint32_t i = 0; i < kv_size; ++i) {
@ -9658,29 +9641,32 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
const size_t elt_size = ggml_element_size(kv_self.k);
const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
kin3d->data = (void *) inp;
inp += ggml_nbytes(kin3d);
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
kin2d->data = (void *) inp;
inp += ggml_nbytes(kin2d);
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
vin3d->data = (void *) inp;
inp += ggml_nbytes(vin3d);
ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
vin2d->data = (void *) inp;
inp += ggml_nbytes(vin2d);
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_head, n_layer,
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
n_embd, kv_head,
elt_size*n_embd, 0);
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
kv_head, n_embd,
elt_size*n_ctx, 0);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d));
}
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);