llama : fix Mamba inference for pipeline parallelism

Tested to work correctly with both `main` and `parallel` examples.
This commit is contained in:
Francis Couture-Harpin 2024-03-12 14:54:35 -04:00
parent 4ddccc2852
commit 937966d75e

147
llama.cpp
View file

@ -2082,7 +2082,7 @@ struct llama_context {
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size] struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [kv_size] struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
@ -5518,6 +5518,9 @@ struct llm_build_context {
lctx.inp_K_shift = nullptr; lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr; lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr; lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
} }
void free() { void free() {
@ -5559,14 +5562,14 @@ struct llm_build_context {
GGML_ASSERT(kv_self.recurrent); GGML_ASSERT(kv_self.recurrent);
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); struct ggml_tensor * state_copy = build_inp_s_copy();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy); conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy); ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy);
// TODO: name the intermediate tensors with cb() // TODO: name the intermediate tensors with cb()
@ -5665,6 +5668,27 @@ struct llm_build_context {
return lctx.inp_cls; return lctx.inp_cls;
} }
struct ggml_tensor * build_inp_s_copy() {
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
cb(lctx.inp_s_copy, "inp_s_copy", -1);
ggml_set_input(lctx.inp_s_copy);
return lctx.inp_s_copy;
}
struct ggml_tensor * build_inp_s_mask() {
lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
cb(lctx.inp_s_mask, "inp_s_mask", -1);
ggml_set_input(lctx.inp_s_mask);
return lctx.inp_s_mask;
}
struct ggml_tensor * build_inp_s_seq() {
lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
cb(lctx.inp_s_seq, "inp_s_seq", -1);
ggml_set_input(lctx.inp_s_seq);
return lctx.inp_s_seq;
}
struct ggml_cgraph * build_llama() { struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -8148,12 +8172,8 @@ struct llm_build_context {
// {n_embd, n_tokens} // {n_embd, n_tokens}
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); struct ggml_tensor * state_mask = build_inp_s_mask();
struct ggml_tensor * state_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); struct ggml_tensor * state_seq = build_inp_s_seq();
lctx.inp_s_mask = state_mask;
lctx.inp_s_seq = state_seq;
ggml_set_input(state_mask);
ggml_set_input(state_seq);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
// (ab)using the KV cache to store the states // (ab)using the KV cache to store the states
@ -8205,7 +8225,7 @@ struct llm_build_context {
ggml_build_forward_expand(gf, ggml_build_forward_expand(gf,
ggml_cpy(ctx0, ggml_cpy(ctx0,
ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
// extract x from x_conv // extract x from x_conv
x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@ -8239,7 +8259,7 @@ struct llm_build_context {
ggml_build_forward_expand(gf, ggml_build_forward_expand(gf,
ggml_cpy(ctx0, ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states)))); ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
} }
if (batch.pos) { if (batch.pos && lctx.inp_pos) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
"non-causal attention with generative models is not supported" "non-causal attention with generative models is not supported"
); );
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (lctx.inp_KQ_mask) {
if (cparams.causal_attn) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
const int64_t n_kv = kv_self.n; if (cparams.causal_attn) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
float * data = (float *) lctx.inp_KQ_mask->data; float * data = (float *) lctx.inp_KQ_mask->data;
// For causal attention, use only the previous KV cells // For causal attention, use only the previous KV cells
// of the correct sequence for each token of the batch. // of the correct sequence for each token of the batch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent. // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j]; const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0]; const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
float f; float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
f = -INFINITY; f = -INFINITY;
} else { } else {
f = 0.0f; f = 0.0f;
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
} }
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
} }
} }
} } else {
} else { // when using kv cache, the mask needs to match the kv cache size
// when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens;
const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
float * data = (float *) lctx.inp_KQ_mask->data; float * data = (float *) lctx.inp_KQ_mask->data;
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0]; const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
float f = -INFINITY; float f = -INFINITY;
for (int s = 0; s < batch.n_seq_id[i]; ++s) { for (int s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id[i][s] == seq_id) { if (batch.seq_id[i][s] == seq_id) {
f = 0.0f; f = 0.0f;
break; break;
}
} }
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
} }
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; for (int i = n_tokens; i < n_stride; ++i) {
} data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
} }
} }
} }
@ -8582,7 +8604,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (hparams.need_kq_pos) { if (hparams.need_kq_pos) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self.n;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer)); GGML_ASSERT(lctx.inp_KQ_pos);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
float * data = (float *) lctx.inp_KQ_pos->data; float * data = (float *) lctx.inp_KQ_pos->data;
@ -8594,6 +8617,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(lctx.inp_mean);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data; float * data = (float *) lctx.inp_mean->data;
@ -8625,6 +8649,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data; uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@ -8645,7 +8670,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (kv_self.recurrent) { if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self.n;
{ if (lctx.inp_s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
float * data = (float *) lctx.inp_s_mask->data; float * data = (float *) lctx.inp_s_mask->data;
@ -8667,7 +8692,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// update the correct state(s)/sequence(s) for each token of the batch. // update the correct state(s)/sequence(s) for each token of the batch.
// Like with the KQ_mask, if a token in the batch has multiple sequences, // Like with the KQ_mask, if a token in the batch has multiple sequences,
// they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
{ if (lctx.inp_s_seq) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
@ -9272,11 +9297,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
} }
if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
llama_set_s_copy(lctx);
{ {
ggml_backend_sched_reset(lctx.sched);
ggml_cgraph * gf = llama_build_graph_s_copy(lctx); ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
ggml_backend_sched_alloc_graph(lctx.sched, gf);
llama_set_s_copy(lctx);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads); llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
need_reserve = true; need_reserve = true;