mamba : dedicate an input tensor for state copy indices
This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers.
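In essence, the change gives recurrent-state copies their own I32 input (inp_s_copy, one source index per cache cell, taken from the new cell "src" field) instead of piggybacking on inp_K_shift. Below is a minimal stand-alone sketch of that mechanism; it uses plain std::vector buffers and hypothetical names (cell, apply_state_copies) in place of the real ggml tensors and graph, so it illustrates the row-gather idea rather than the actual llama.cpp implementation.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in: a cache cell only keeps the index of the cell whose
// state it should take (cells[i].src == i means "keep my own state").
struct cell { int32_t src; };

// Row-wise gather over a [n_cells x n_embd] state buffer, mimicking what
// ggml_get_rows + ggml_cpy do with inp_s_copy in the real graph.
static void apply_state_copies(std::vector<float> & states,
                               const std::vector<int32_t> & s_copy,
                               size_t n_embd) {
    const std::vector<float> old_states = states; // the gather reads the old contents
    for (size_t i = 0; i < s_copy.size(); ++i) {
        const float * from = &old_states[(size_t) s_copy[i] * n_embd];
        std::copy(from, from + n_embd, &states[i * n_embd]);
    }
}

int main() {
    const size_t n_cells = 4, n_embd = 2;
    std::vector<cell> cells(n_cells);
    for (size_t i = 0; i < n_cells; ++i) { cells[i].src = (int32_t) i; }

    // "seq_cp(src=0, dst=2)": cell 2 will take its state from cell 0
    cells[2].src = cells[0].src;

    // llama_set_s_copy() equivalent: flatten the indices into the I32 input
    std::vector<int32_t> s_copy(n_cells);
    for (size_t i = 0; i < n_cells; ++i) { s_copy[i] = cells[i].src; }

    // per-layer state buffer: cell i holds the value (i+1) in every column
    std::vector<float> states(n_cells * n_embd);
    for (size_t i = 0; i < n_cells; ++i) {
        for (size_t j = 0; j < n_embd; ++j) { states[i*n_embd + j] = (float) (i + 1); }
    }

    apply_state_copies(states, s_copy, n_embd);

    // after the copy pass, every cell points back at itself again
    for (size_t i = 0; i < n_cells; ++i) { cells[i].src = (int32_t) i; }

    printf("cell 2 now holds %.0f (copied from cell 0)\n", states[2*n_embd]);
    return 0;
}

In the actual diff below, the gather is ggml_get_rows on the reshaped k_l/v_l state tensors inside the new build_s_copy() graph, and the reset of src back to i happens in llama_kv_cache_update_internal().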
parent 34e2fca8eb
commit 79d636cc7e
1 changed file with 91 additions and 31 deletions
llama.cpp (122 changes)
@@ -1782,6 +1782,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta = 0;
+    int32_t   src   = 0; // used by recurrent state models to copy states
 
     std::set<llama_seq_id> seq_id;
 
@@ -1802,6 +1803,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy   = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;
 
@@ -2043,7 +2045,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift;  // I32 [kv_size]
     struct ggml_tensor * inp_mean;     // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;      // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask;   // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy;   // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;   // F32 [kv_size]
     struct ggml_tensor * inp_s_seq;    // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -2085,9 +2088,9 @@ static bool llama_kv_cache_init(
 
     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }
 
 #ifdef GGML_USE_CLBLAST
     offload = false;
@@ -2340,19 +2343,20 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;
 
-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;
 
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
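A detail worth spelling out from the hunk above: llama_kv_cache_seq_cp stores cache.cells[seq_id_src].src rather than seq_id_src itself, which is what makes copy chains collapse to the original cell before the copy graph ever runs. A small self-contained illustration, with a hypothetical mini_cell/seq_cp standing in for the real cache types:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical miniature of the bookkeeping done in llama_kv_cache_seq_cp:
// only the per-cell source index matters here.
struct mini_cell { int32_t src; };

static void seq_cp(std::vector<mini_cell> & cells, int32_t seq_id_src, int32_t seq_id_dst) {
    // take the source of the source, as the real code does, so a chain of
    // copies always resolves to the cell that actually holds the state
    cells[seq_id_dst].src = cells[seq_id_src].src;
}

int main() {
    std::vector<mini_cell> cells(4);
    for (int32_t i = 0; i < 4; ++i) { cells[i].src = i; }

    seq_cp(cells, /*src=*/0, /*dst=*/1); // 1 copies from 0
    seq_cp(cells, /*src=*/1, /*dst=*/2); // 2 copies from 1, which resolves to 0
    seq_cp(cells, /*src=*/2, /*dst=*/3); // 3 copies from 2, still resolves to 0

    for (int32_t i = 0; i < 4; ++i) {
        printf("cell %d gathers its state from cell %d\n", i, cells[i].src);
    }
    return 0; // every destination reports cell 0 as its source
}

Because every destination ends up pointing at the original cell directly, a single ggml_get_rows pass is enough no matter how many llama_kv_cache_seq_cp calls were chained.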
@@ -5445,21 +5449,7 @@ struct llm_build_context {
     struct ggml_cgraph * build_k_shift() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // TODO: do this in a another graph with a dedicated input tensor
-        if (kv_self.unlimited) {
-            for (int il = 0; il < n_layer; ++il) {
-                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-                ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states,  kv_self.v_l[il]));
-            }
-
-            return gf;
-        }
+        GGML_ASSERT(kv_self.size == n_ctx);
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -5479,6 +5469,25 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states,  kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -8211,6 +8220,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -8350,6 +8376,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }
 
+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -8464,7 +8502,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
         const int64_t n_kv = kv_self.n;
 
         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
@@ -8472,9 +8510,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
                 llama_seq_id    seq_id       = i + lctx.kv_self.head;
                 llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
                 bool            has_self_seq = kv_cell.has_seq_id(seq_id);
 
                 data[i] = (float) has_self_seq;
 
@@ -8998,7 +9036,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
-    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
         llama_set_k_shift(lctx);
 
         {
@@ -9013,7 +9051,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;
 
             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
             }
         }
     }
@@ -12667,7 +12725,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc */ true,
             };
@@ -12682,6 +12740,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
         ctx->inp_cls  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
         if (ctx->kv_self.unlimited) {
+            ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
             ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
             ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
         }
@@ -12695,6 +12754,7 @@ struct llama_context * llama_new_context_with_model(
         ggml_set_name(ctx->inp_mean, "inp_mean");
         ggml_set_name(ctx->inp_cls, "inp_cls");
         if (ctx->kv_self.unlimited) {
+            ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
             ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
             ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
         }