llama : fix Mamba inference for pipeline parallelism

Tested to work correctly with both the `main` and `parallel` examples.

parent 4ddccc2852
commit 937966d75e

1 changed file with 88 additions and 59 deletions:

llama.cpp (147 lines changed)
@@ -2082,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean;   // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;    // I32 [n_batch]
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;  // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
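The only functional detail in this hunk is the updated shape comment: `inp_s_mask` is now described as F32 [1, kv_size], matching the 2-D tensor created by the new `build_inp_s_mask()` further down. A plausible reading is that a [1, n] mask can be broadcast-multiplied over a [d, n] state matrix to zero whole per-cell columns. The standalone C++ sketch below illustrates only that broadcasting idea; it does not call ggml, and the sizes are made up.

#include <cstdio>
#include <vector>

// Illustrative only: scale a [d, n] column-major "state" matrix by a [1, n] mask,
// zeroing every column whose mask entry is 0.0f (one column per KV cell).
int main() {
    const int d = 3; // rows per state (e.g. d_inner)
    const int n = 4; // number of KV cells

    std::vector<float> states(d * n, 1.0f);             // states[i + j*d] = row i of cell j
    std::vector<float> mask = {1.0f, 0.0f, 1.0f, 0.0f}; // keep cells 0 and 2

    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < d; ++i) {
            states[i + j*d] *= mask[j]; // the [1, n] mask broadcasts over the d rows
        }
    }

    for (int j = 0; j < n; ++j) {
        printf("cell %d: %.1f %.1f %.1f\n", j, states[0 + j*d], states[1 + j*d], states[2 + j*d]);
    }
    return 0;
}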
@@ -5518,6 +5518,9 @@ struct llm_build_context {
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean = nullptr;
         lctx.inp_cls = nullptr;
+        lctx.inp_s_copy = nullptr;
+        lctx.inp_s_mask = nullptr;
+        lctx.inp_s_seq = nullptr;
     }
 
     void free() {
@@ -5559,14 +5562,14 @@ struct llm_build_context {
 
         GGML_ASSERT(kv_self.recurrent);
 
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0, ssm_states, state_copy);
 
             // TODO: name the intermediate tensors with cb()
 
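Here the graph that shuffles recurrent states stops creating `inp_s_copy` inline and calls the new `build_inp_s_copy()` helper instead; the resulting I32 tensor holds one source row index per KV cell, and `ggml_get_rows` gathers the Mamba conv/ssm state rows accordingly. A minimal standalone sketch of that gather semantics follows, assuming contiguous rows; the `get_rows` helper here is illustrative and is not the ggml implementation.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Gather rows of `src` (rows of `row_size` floats) according to `copy`,
// which plays the role of an I32 inp_s_copy tensor: dst row r = src row copy[r].
static std::vector<float> get_rows(const std::vector<float> & src,
                                   const std::vector<int32_t> & copy,
                                   size_t row_size) {
    std::vector<float> dst(copy.size() * row_size);
    for (size_t r = 0; r < copy.size(); ++r) {
        const float * s = src.data() + copy[r] * row_size;
        std::copy(s, s + row_size, dst.data() + r * row_size);
    }
    return dst;
}

int main() {
    const size_t row_size = 2; // e.g. hparams.n_embd_k_s() in the real code
    // Three per-cell states: cell 0 -> {0,0}, cell 1 -> {1,1}, cell 2 -> {2,2}.
    std::vector<float> states = {0, 0, 1, 1, 2, 2};
    // Rearrange cells: the state that was in cell 2 should now be read as cell 0, etc.
    std::vector<int32_t> copy = {2, 0, 1};

    std::vector<float> out = get_rows(states, copy, row_size);
    for (size_t r = 0; r < copy.size(); ++r) {
        printf("row %zu: %.0f %.0f\n", r, out[r*row_size], out[r*row_size + 1]);
    }
    return 0;
}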
@@ -5665,6 +5668,27 @@ struct llm_build_context {
         return lctx.inp_cls;
     }
 
+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
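The three new helpers follow the existing `build_inp_*` pattern: create the tensor, name it through `cb(...)`, and mark it with `ggml_set_input()` so the backend scheduler treats it as a graph input. Combined with the `nullptr` resets added at +5518 and the `if (lctx.inp_...)` guards later in `llama_set_inputs`, only the inputs that the current graph actually created get filled. A standalone sketch of that lifecycle is below; `GraphInputs`, `build_recurrent_graph`, and `fill_inputs` are hypothetical names, not llama.cpp API.

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hypothetical stand-in for the lctx.inp_* tensors: every graph build starts with
// no inputs and creates only the ones that particular graph needs.
struct GraphInputs {
    std::optional<std::vector<float>>   s_mask; // like lctx.inp_s_mask
    std::optional<std::vector<int32_t>> s_seq;  // like lctx.inp_s_seq
};

// Build side: a recurrent (Mamba-like) graph declares both inputs...
static GraphInputs build_recurrent_graph(int n_kv, int n_tokens) {
    GraphInputs in;
    in.s_mask.emplace(n_kv, 0.0f);
    in.s_seq.emplace(n_kv * n_tokens, -1);
    return in;
}

// ...while a plain attention graph declares neither.
static GraphInputs build_attention_graph() {
    return GraphInputs{};
}

// Fill side, mirroring llama_set_inputs: only touch inputs that were created.
static void fill_inputs(GraphInputs & in) {
    if (in.s_mask) { (*in.s_mask)[0] = 1.0f; std::printf("filled s_mask\n"); }
    if (in.s_seq)  { (*in.s_seq)[0]  = 0;    std::printf("filled s_seq\n");  }
    if (!in.s_mask && !in.s_seq)     {       std::printf("nothing to fill\n"); }
}

int main() {
    GraphInputs a = build_recurrent_graph(8, 4);
    fill_inputs(a); // fills both

    GraphInputs b = build_attention_graph();
    fill_inputs(b); // skips both instead of dereferencing missing inputs
    return 0;
}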
@@ -8148,12 +8172,8 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq  = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
@@ -8205,7 +8225,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+                    ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
 
             // extract x from x_conv
             x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@@ -8239,7 +8259,7 @@ struct llm_build_context {
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0,
                     ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+                    ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
 
             struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
 
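These two hunks change the destination offset of the state write-back from `kv_self.head` to the build context's `kv_head`, presumably so the offsets baked into the graph stay consistent with the values the graph was built (and reserved) with, rather than reading the live cache head. The offset itself is just "skip `head` cells worth of state"; a standalone sketch of that arithmetic, with illustrative sizes, follows.

#include <cstddef>
#include <cstdio>

// Byte offset of cell `head` inside a flat per-layer conv-state buffer laid out as
// kv_size cells of (d_conv - 1) * d_inner floats each (mirroring the ggml_view_1d offset).
static size_t conv_state_offset(size_t head, size_t d_conv, size_t d_inner) {
    return head * (d_conv - 1) * d_inner * sizeof(float);
}

// Same idea for the ssm states: kv_size cells of d_state * d_inner floats each.
static size_t ssm_state_offset(size_t head, size_t d_state, size_t d_inner) {
    return head * d_state * d_inner * sizeof(float);
}

int main() {
    const size_t d_conv = 4, d_state = 16, d_inner = 32; // illustrative sizes
    const size_t heads[] = {0, 1, 5};
    for (size_t head : heads) {
        printf("head=%zu conv offset=%zu bytes, ssm offset=%zu bytes\n",
               head,
               conv_state_offset(head, d_conv, d_inner),
               ssm_state_offset(head, d_state, d_inner));
    }
    return 0;
}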
@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }
 
-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;
 
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             "non-causal attention with generative models is not supported"
         );
 
+    if (lctx.inp_KQ_mask) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
             float * data = (float *) lctx.inp_KQ_mask->data;
 
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_pos    pos    = batch.pos[j];
                     const llama_seq_id seq_id = batch.seq_id[j][0];
 
                     for (int i = 0; i < n_kv; ++i) {
                         float f;
                         if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
                             f = -INFINITY;
                         } else {
                             f = 0.0f;
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
                 }
             }
         } else {
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_tokens = batch.n_tokens;
             const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
             float * data = (float *) lctx.inp_KQ_mask->data;
 
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_seq_id seq_id = batch.seq_id[j][0];
 
                     for (int i = 0; i < n_tokens; ++i) {
                         float f = -INFINITY;
                         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                             if (batch.seq_id[i][s] == seq_id) {
                                 f = 0.0f;
                                 break;
                             }
                         }
 
                         data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
                     }
 
                     for (int i = n_tokens; i < n_stride; ++i) {
                         data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
                     }
                 }
             }
         }
+    }
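Apart from the `assert` to `GGML_ASSERT` change and the new `if (lctx.inp_KQ_mask)` wrapper (which re-indents the whole block), the mask fill is unchanged: each token row gets 0.0f for KV cells of the same sequence at a position not after it, and -INFINITY everywhere else. The standalone sketch below mirrors that causal fill with simple stand-ins for the KV cells and the batch.

#include <cmath>
#include <cstdio>
#include <vector>

struct Cell  { int seq_id; int pos; }; // stand-in for a KV cache cell
struct Token { int seq_id; int pos; }; // stand-in for one token of the batch

int main() {
    // Two sequences interleaved in the cache, one token per cell.
    std::vector<Cell> cells = { {0, 0}, {1, 0}, {0, 1}, {1, 1} };
    // The batch holds the next token of each sequence.
    std::vector<Token> batch = { {0, 2}, {1, 2} };

    const int n_kv     = (int) cells.size();
    const int n_tokens = (int) batch.size();
    std::vector<float> mask(n_kv * n_tokens);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            // visible only if same sequence and not in the future of token j
            const bool visible = cells[i].seq_id == batch[j].seq_id &&
                                 cells[i].pos    <= batch[j].pos;
            mask[j*n_kv + i] = visible ? 0.0f : -INFINITY;
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%6s ", std::isinf(mask[j*n_kv + i]) ? "-inf" : "0");
        }
        printf("\n");
    }
    return 0;
}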
@@ -8582,7 +8604,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (hparams.need_kq_pos) {
         const int64_t n_kv = kv_self.n;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+        GGML_ASSERT(lctx.inp_KQ_pos);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
 
         float * data = (float *) lctx.inp_KQ_pos->data;
 
@@ -8594,6 +8617,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
         float * data = (float *) lctx.inp_mean->data;
@@ -8625,6 +8649,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@@ -8645,7 +8670,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
-        {
+        if (lctx.inp_s_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
@@ -8667,7 +8692,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // update the correct state(s)/sequence(s) for each token of the batch.
         // Like with the KQ_mask, if a token in the batch has multiple sequences,
         // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        {
+        if (lctx.inp_s_seq) {
            const int64_t n_tokens = batch.n_tokens;
 
            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
@@ -9272,11 +9297,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 
     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
 
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
 
             need_reserve = true;
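The reordering in this last hunk is the part that matters most for pipeline parallelism: the ad-hoc `s_copy` graph is now allocated through the scheduler (`ggml_backend_sched_reset` followed by `ggml_backend_sched_alloc_graph`) before `llama_set_s_copy` writes its input, since a graph input has no backing buffer until the graph is allocated. A minimal standalone sketch of that ordering constraint, with hypothetical names (`LazyInput`, `alloc_graph`), follows; it is not the ggml scheduler API.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a graph input tensor: it has a shape decided at build
// time, but no backing storage until an "allocate" step runs.
struct LazyInput {
    size_t           n = 0;
    std::vector<int> data; // empty until allocated
    bool allocated() const { return data.size() == n; }
};

static void build_graph(LazyInput & inp, size_t n) { inp.n = n; }              // decide the shape only
static void alloc_graph(LazyInput & inp)           { inp.data.resize(inp.n); } // now storage exists

static void set_inputs(LazyInput & inp) {
    assert(inp.allocated() && "inputs must be allocated before they are written");
    for (size_t i = 0; i < inp.n; ++i) {
        inp.data[i] = (int) i; // e.g. the s_copy cell indices
    }
}

int main() {
    LazyInput s_copy;
    build_graph(s_copy, 8);
    alloc_graph(s_copy); // the fix: allocate first...
    set_inputs(s_copy);  // ...then fill the input
    printf("s_copy[7] = %d\n", s_copy.data[7]);
    return 0;
}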