llama : flash_attn cparam + fix defrag
This commit is contained in:
parent
2c41180e88
commit
599ce84a71
4 changed files with 198 additions and 163 deletions
|
@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
params.cont_batching = true;
|
params.cont_batching = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-fa" || arg == "--flash-attn") {
|
||||||
|
params.flash_attn = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "--color") {
|
if (arg == "--color") {
|
||||||
params.use_color = true;
|
params.use_color = true;
|
||||||
return true;
|
return true;
|
||||||
|
@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||||
cparams.cb_eval = params.cb_eval;
|
cparams.cb_eval = params.cb_eval;
|
||||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||||
cparams.offload_kqv = !params.no_kv_offload;
|
cparams.offload_kqv = !params.no_kv_offload;
|
||||||
|
cparams.flash_attn = params.flash_attn;
|
||||||
|
|
||||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||||
|
@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
||||||
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
|
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
|
||||||
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
|
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
|
||||||
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
|
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
|
||||||
|
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
|
||||||
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
||||||
|
|
||||||
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
|
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
|
||||||
|
|
|
@ -148,6 +148,7 @@ struct gpt_params {
|
||||||
bool multiline_input = false; // reverse the usage of `\`
|
bool multiline_input = false; // reverse the usage of `\`
|
||||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||||
|
bool flash_attn = false; // flash attention
|
||||||
|
|
||||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||||
bool ignore_eos = false; // ignore generated EOS tokens
|
bool ignore_eos = false; // ignore generated EOS tokens
|
||||||
|
|
173
llama.cpp
173
llama.cpp
|
@ -107,8 +107,6 @@
|
||||||
#define LLAMA_MAX_NODES 8192
|
#define LLAMA_MAX_NODES 8192
|
||||||
#define LLAMA_MAX_EXPERTS 60
|
#define LLAMA_MAX_EXPERTS 60
|
||||||
|
|
||||||
#define LLAMA_FLASH_ATTN
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// logging
|
// logging
|
||||||
//
|
//
|
||||||
|
@ -1899,6 +1897,7 @@ struct llama_cparams {
|
||||||
bool embeddings;
|
bool embeddings;
|
||||||
bool causal_attn;
|
bool causal_attn;
|
||||||
bool offload_kqv;
|
bool offload_kqv;
|
||||||
|
bool flash_attn;
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type;
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
|
@ -5938,15 +5937,17 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||||
static void llm_build_kv_store(
|
static void llm_build_kv_store(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache & kv,
|
const llama_kv_cache & kv,
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
struct ggml_tensor * k_cur,
|
struct ggml_tensor * k_cur,
|
||||||
struct ggml_tensor * v_cur,
|
struct ggml_tensor * v_cur,
|
||||||
int64_t n_ctx,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t kv_head,
|
int32_t kv_head,
|
||||||
const llm_build_cb & cb,
|
const llm_build_cb & cb,
|
||||||
int64_t il) {
|
int64_t il) {
|
||||||
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
|
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
|
@ -5959,20 +5960,18 @@ static void llm_build_kv_store(
|
||||||
// important: storing RoPE-ed version of K in the KV cache!
|
// important: storing RoPE-ed version of K in the KV cache!
|
||||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
|
||||||
|
|
||||||
#if defined(LLAMA_FLASH_ATTN)
|
if (cparams.flash_attn) {
|
||||||
// NOTE: the V cache is not transposed when using FLASH attention !!
|
// NOTE: the V cache is not transposed when using FLASH attention !!
|
||||||
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
|
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
|
||||||
(ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
|
(ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
|
||||||
cb(v_cache_view, "v_cache_view", il);
|
cb(v_cache_view, "v_cache_view", il);
|
||||||
|
|
||||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
|
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
|
||||||
|
} else {
|
||||||
GGML_UNUSED(n_ctx);
|
|
||||||
#else
|
|
||||||
// compute the transposed [n_tokens, n_embd] V matrix
|
// compute the transposed [n_tokens, n_embd] V matrix
|
||||||
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
|
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
|
||||||
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
|
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
|
||||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
|
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur);
|
||||||
cb(v_cur_t, "v_cur_t", il);
|
cb(v_cur_t, "v_cur_t", il);
|
||||||
|
|
||||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
|
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
|
||||||
|
@ -5980,7 +5979,7 @@ static void llm_build_kv_store(
|
||||||
(kv_head)*ggml_element_size(kv.v_l[il]));
|
(kv_head)*ggml_element_size(kv.v_l[il]));
|
||||||
|
|
||||||
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
|
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
|
||||||
#endif
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_tensor * llm_build_norm(
|
static struct ggml_tensor * llm_build_norm(
|
||||||
|
@ -6111,6 +6110,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache & kv,
|
const llama_kv_cache & kv,
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
struct ggml_tensor * wo,
|
struct ggml_tensor * wo,
|
||||||
|
@ -6118,12 +6118,12 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_tensor * q_cur,
|
struct ggml_tensor * q_cur,
|
||||||
struct ggml_tensor * kq_mask,
|
struct ggml_tensor * kq_mask,
|
||||||
struct ggml_tensor * kq_pos,
|
struct ggml_tensor * kq_pos,
|
||||||
int64_t n_ctx,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
const llm_build_cb & cb,
|
const llm_build_cb & cb,
|
||||||
int il) {
|
int il) {
|
||||||
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
const int64_t n_head = hparams.n_head;
|
const int64_t n_head = hparams.n_head;
|
||||||
const int64_t n_head_kv = hparams.n_head_kv;
|
const int64_t n_head_kv = hparams.n_head_kv;
|
||||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
|
@ -6143,7 +6143,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
#if defined(LLAMA_FLASH_ATTN)
|
if (cparams.flash_attn) {
|
||||||
GGML_UNUSED(model);
|
GGML_UNUSED(model);
|
||||||
GGML_UNUSED(n_ctx);
|
GGML_UNUSED(n_ctx);
|
||||||
|
|
||||||
|
@ -6159,7 +6159,10 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
cb(v, "v", il);
|
cb(v, "v", il);
|
||||||
|
|
||||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
|
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
|
||||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
|
|
||||||
|
if (model.arch == LLM_ARCH_PHI2) {
|
||||||
|
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||||
|
}
|
||||||
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
|
||||||
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
|
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
|
||||||
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
|
||||||
|
@ -6167,7 +6170,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
//printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
|
//printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
|
||||||
|
|
||||||
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
|
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
|
||||||
#else
|
} else {
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
|
@ -6233,7 +6236,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
|
|
||||||
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
|
||||||
cb(cur, "kqv_merged_cont", il);
|
cb(cur, "kqv_merged_cont", il);
|
||||||
#endif
|
}
|
||||||
|
|
||||||
ggml_build_forward_expand(graph, cur);
|
ggml_build_forward_expand(graph, cur);
|
||||||
|
|
||||||
|
@ -6253,6 +6256,7 @@ static struct ggml_tensor * llm_build_kv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache & kv,
|
const llama_kv_cache & kv,
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
struct ggml_tensor * wo,
|
struct ggml_tensor * wo,
|
||||||
|
@ -6262,7 +6266,6 @@ static struct ggml_tensor * llm_build_kv(
|
||||||
struct ggml_tensor * q_cur,
|
struct ggml_tensor * q_cur,
|
||||||
struct ggml_tensor * kq_mask,
|
struct ggml_tensor * kq_mask,
|
||||||
struct ggml_tensor * kq_pos,
|
struct ggml_tensor * kq_pos,
|
||||||
int64_t n_ctx,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t kv_head,
|
int32_t kv_head,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
|
@ -6276,12 +6279,12 @@ static struct ggml_tensor * llm_build_kv(
|
||||||
ggml_build_forward_expand(graph, k_cur);
|
ggml_build_forward_expand(graph, k_cur);
|
||||||
ggml_build_forward_expand(graph, v_cur);
|
ggml_build_forward_expand(graph, v_cur);
|
||||||
|
|
||||||
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
|
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
||||||
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
|
q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
|
@ -6323,6 +6326,8 @@ struct llm_build_context {
|
||||||
const int32_t kv_head; // index of where we store new KV data in the cache
|
const int32_t kv_head; // index of where we store new KV data in the cache
|
||||||
const int32_t n_orig_ctx;
|
const int32_t n_orig_ctx;
|
||||||
|
|
||||||
|
const bool flash_attn;
|
||||||
|
|
||||||
const enum llama_pooling_type pooling_type;
|
const enum llama_pooling_type pooling_type;
|
||||||
const enum llama_rope_type rope_type;
|
const enum llama_rope_type rope_type;
|
||||||
|
|
||||||
|
@ -6369,6 +6374,7 @@ struct llm_build_context {
|
||||||
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
||||||
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
||||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||||
|
flash_attn (cparams.flash_attn),
|
||||||
pooling_type (cparams.pooling_type),
|
pooling_type (cparams.pooling_type),
|
||||||
rope_type (hparams.rope_type),
|
rope_type (hparams.rope_type),
|
||||||
cb (cb),
|
cb (cb),
|
||||||
|
@ -6483,15 +6489,31 @@ struct llm_build_context {
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
||||||
|
|
||||||
ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
ggml_tensor * view_v_src;
|
||||||
|
ggml_tensor * view_v_dst;
|
||||||
|
|
||||||
|
if (flash_attn) {
|
||||||
|
// NOTE: the V cache is not transposed when using flash attention
|
||||||
|
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||||
|
n_embd_v_gqa, nm,
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
||||||
|
|
||||||
|
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||||
|
n_embd_v_gqa, nm,
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
||||||
|
} else {
|
||||||
|
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, i));
|
ggml_row_size(kv_self.v_l[il]->type, i));
|
||||||
|
|
||||||
ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, id));
|
ggml_row_size(kv_self.v_l[il]->type, id));
|
||||||
|
}
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
||||||
|
@ -6640,9 +6662,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -6846,9 +6868,9 @@ struct llm_build_context {
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -6953,9 +6975,9 @@ struct llm_build_context {
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7073,9 +7095,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7198,9 +7220,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7343,9 +7365,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7447,9 +7469,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7651,9 +7673,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Vcur, "Vcur", il);
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7747,9 +7769,9 @@ struct llm_build_context {
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8040,9 +8062,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8171,14 +8193,15 @@ struct llm_build_context {
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
} else {
|
} else {
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
|
||||||
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8320,9 +8343,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8438,9 +8461,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8557,9 +8580,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8671,9 +8694,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8817,9 +8840,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8919,9 +8942,9 @@ struct llm_build_context {
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
struct ggml_tensor * sa_out = cur;
|
struct ggml_tensor * sa_out = cur;
|
||||||
|
|
||||||
|
@ -9022,9 +9045,9 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9129,9 +9152,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9245,9 +9268,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9362,9 +9385,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9492,9 +9515,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9613,9 +9636,9 @@ struct llm_build_context {
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9732,9 +9755,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -10022,9 +10045,9 @@ struct llm_build_context {
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -11016,7 +11039,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||||
// each move requires 6*n_layer tensors (see build_defrag)
|
// each move requires 6*n_layer tensors (see build_defrag)
|
||||||
// - source view, destination view, copy operation
|
// - source view, destination view, copy operation
|
||||||
// - x2 for keys and values
|
// - x2 for keys and values
|
||||||
const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
|
//const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
|
||||||
|
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
||||||
|
const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
|
||||||
|
|
||||||
// determine which KV cells to move where
|
// determine which KV cells to move where
|
||||||
//
|
//
|
||||||
|
@ -14626,6 +14651,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.embeddings =*/ false,
|
/*.embeddings =*/ false,
|
||||||
/*.offload_kqv =*/ true,
|
/*.offload_kqv =*/ true,
|
||||||
|
/*.flash_attn =*/ false,
|
||||||
/*.abort_callback =*/ nullptr,
|
/*.abort_callback =*/ nullptr,
|
||||||
/*.abort_callback_data =*/ nullptr,
|
/*.abort_callback_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
|
@ -14795,6 +14821,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
cparams.defrag_thold = params.defrag_thold;
|
cparams.defrag_thold = params.defrag_thold;
|
||||||
cparams.embeddings = params.embeddings;
|
cparams.embeddings = params.embeddings;
|
||||||
cparams.offload_kqv = params.offload_kqv;
|
cparams.offload_kqv = params.offload_kqv;
|
||||||
|
cparams.flash_attn = params.flash_attn;
|
||||||
cparams.pooling_type = params.pooling_type;
|
cparams.pooling_type = params.pooling_type;
|
||||||
|
|
||||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||||
|
|
1
llama.h
1
llama.h
|
@ -270,6 +270,7 @@ extern "C" {
|
||||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||||
bool embeddings; // if true, extract embeddings (together with logits)
|
bool embeddings; // if true, extract embeddings (together with logits)
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||||
|
bool flash_attn; // whether to use flash attention
|
||||||
|
|
||||||
// Abort callback
|
// Abort callback
|
||||||
// if it returns true, execution of llama_decode() will be aborted
|
// if it returns true, execution of llama_decode() will be aborted
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue