diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 998877c26..a42458e63 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2338,7 +2338,7 @@ class MambaModel(Model): self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) @@ -2384,6 +2384,135 @@ class MambaModel(Model): ) +@Model.register("JambaForCausalLM") +class JambaModel(Model): + model_arch = gguf.MODEL_ARCH.JAMBA + + def get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0): + 
name = name.replace(".experts.0.", ".") + + # process the experts separately + if ".feed_forward.experts." in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield new_name, data_torch + + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + # same as Mamba + def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + del n_dims # unused + + return bid is not None and new_name in ( + self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ + gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.SSM_X, + gguf.MODEL_TENSOR.SSM_DT, + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ] + ) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/ggml-metal.m b/ggml-metal.m index c9e570dbf..c39f1c151 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -187,6 +187,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const return true; case GGML_OP_FLASH_ATTN_EXT: return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && @@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute( // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, // ggml_is_contiguous(src1), src1->name); //} + //if (src2) { + // GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne20, ne21, ne22, + // ggml_is_contiguous(src2), src2->name); + //} //if (dst) { // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, // dst->name); @@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; + case GGML_OP_SSM_CONV: + { + id pipeline = 
nil; + + //pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + //[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + //[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + //[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + //[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + //[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + //[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + //[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + //[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + //[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + //[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + //[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + //[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + //[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + //[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + //[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + //[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + //[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + //[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + //[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + //[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + //[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + //[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + //[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + //[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + //[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + //[encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + //[encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + //if (bcast_row) { + // const int64_t n = ggml_nelements(dst)/4; + + // [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + //} else { + // const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + // [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + //} + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 8ff70d7a7..0ce719cb2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2698,6 +2698,29 @@ kernel void kernel_flash_attn_ext_vec_f16( template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +kernel void kernel_ssm_conv_f32( + device const float * src0, + device const float * src1, + device const float * src2, + device const int32_t * src3, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne11, + constant int64_t & ne20, + + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb21, + constant uint64_t & nb22, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +} + kernel void kernel_cpy_f16_f16( device const half 
* src0, device half * dst, diff --git a/ggml.c b/ggml.c index 9e72b7a76..b2b7b87fe 100644 --- a/ggml.c +++ b/ggml.c @@ -7094,19 +7094,18 @@ struct ggml_tensor * ggml_ssm_conv( GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_matrix(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(ggml_is_vector(sq)); GGML_ASSERT(sq->type == GGML_TYPE_I32); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t n_rs = s->ne[2]; GGML_ASSERT( s->ne[0] == d_conv - 1); GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + GGML_ASSERT(sq->ne[0] == n_tokens); bool is_node = false; @@ -7115,8 +7114,8 @@ struct ggml_tensor * ggml_ssm_conv( is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7169,7 +7168,7 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -16241,9 +16240,9 @@ static void ggml_compute_forward_ssm_conv_f32( const int nc = src2->ne[0]; // d_conv const int nr = src0->ne[1]; // d_inner const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int n_rs = src0->ne[2]; // max number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16260,10 +16259,12 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src3->data; // {n_tokens} + + if (n_rs > 1) { // multiple sequences means it's hard to know when it's the first time a state is read, // so copy them all over to the destination, just to be sure. 
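+        // (each token now selects exactly one destination state via the per-token sq
+        //  indices used below, so the old "handle copies when there are multiple
+        //  output states" loop further down is no longer needed)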
- for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); // can't use memcpy because of d_conv vs d_conv - 1 @@ -16277,19 +16278,19 @@ static void ggml_compute_forward_ssm_conv_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int32_t sq_i = sq[i2]; + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} + float * s0; // {d_conv - 1, d_inner, n_rs} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} int ne0s0; - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} ne0s0 = src0->ne[0]; } else { // the source is the last (d_conv - 1) columns of the destination @@ -16307,18 +16308,6 @@ static void ggml_compute_forward_ssm_conv_f32( s[(nc - 1) + i1*nc] = x0[i1]; } - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product @@ -16370,7 +16359,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nc = src0->ne[0]; // d_state const int64_t nr = src0->ne[1]; // d_inner const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -16379,6 +16368,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states @@ -16394,10 +16384,12 @@ static void ggml_compute_forward_ssm_scan_f32( const 
int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src6->data; // {n_tokens} + + if (n_rs > 1) { // it's hard to know if the source states have already been copied // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); memcpy(s, s0, nc*ir*sizeof(float)); @@ -16405,21 +16397,21 @@ static void ggml_compute_forward_ssm_scan_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + int32_t sq_i = sq[i2]; + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} } else { // otherwise the source is the same as the destination s0 = s; @@ -16442,18 +16434,6 @@ static void ggml_compute_forward_ssm_scan_f32( } y[i1] = sumf; } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c9ae259e1..547604807 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() STARCODER2 = auto() MAMBA = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -182,7 +183,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = 
auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_OUT = auto() @@ -216,6 +220,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -263,7 +268,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", } @@ -682,6 +690,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8b41b54ea..272ef4a80 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -385,8 +385,11 @@ class GGUFWriter: def add_head_count(self, count: int) -> None: self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) - def add_head_count_kv(self, count: int) -> None: - self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + def add_head_count_kv(self, count: int | Sequence[int]) -> None: + if isinstance(count, int): + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + else: + self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) def add_key_length(self, length: int) -> None: self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8b1b21d78..c81600151 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -206,6 +206,8 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -214,6 +216,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.feed_forward.router", # jamba ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -245,6 +248,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic + "model.layers.{bid}.feed_forward.up_proj", # jamba ), 
MODEL_TENSOR.FFN_UP_EXP: ( @@ -274,6 +278,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic + "model.layers.{bid}.feed_forward.gate_proj", # jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -309,6 +314,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w2", # arctic + "model.layers.{bid}.feed_forward.down_proj", # jamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -350,38 +356,59 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba ), } diff --git a/llama.cpp b/llama.cpp index 3c9fe15bb..0a1385788 100644 --- a/llama.cpp +++ b/llama.cpp @@ -217,6 +217,7 @@ enum llm_arch { LLM_ARCH_GEMMA, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -254,6 +255,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -471,7 +473,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, }; @@ -969,6 +974,37 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" 
}, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1783,6 +1819,9 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + // TODO: find a more compact way to add more per-layer hyper-parameters + std::vector n_head_kv_vec; + float f_norm_eps; float f_norm_rms_eps; @@ -1823,6 +1862,8 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_kv_vec != other.n_head_kv_vec) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1842,29 +1883,46 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head_kv_l(uint32_t layer) const { + if (layer < n_head_kv_vec.size()) { + int32_t n_hkv_l = n_head_kv_vec[layer]; + // TODO: what should happen when it's negative? + GGML_ASSERT(n_hkv_l >= 0); + return n_hkv_l; + } + return n_head_kv; + } + + uint32_t n_gqa(uint32_t layer = 0) const { + uint32_t n_head_kv = n_head_kv_l(layer); if (n_head_kv == 0) { return 0; } return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? 
ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -1913,6 +1971,9 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * ssm_dt_norm; + struct ggml_tensor * ssm_b_norm; + struct ggml_tensor * ssm_c_norm; // attention struct ggml_tensor * wq; @@ -1980,7 +2041,6 @@ struct llama_layer { struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states std::set seq_id; @@ -2001,8 +2061,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching @@ -2023,9 +2081,553 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * k : k_l) { + size += ggml_nrows(k) * ggml_row_size(k->type, k->ne[0]); + } + for (struct ggml_tensor * v : v_l) { + size += ggml_nrows(v) * ggml_row_size(v->type, v->ne[0]); + } + return size; + } +}; + +// for recurrent models, use a tree of sequences to simplify +// quickly finding the tail cell of each sequence +// TODO: drop the _rs_ infix +struct llama_rs_seq_node { + llama_seq_id seq_id = -1; + int32_t next_cell = -1; + + // needed for automatic typecasting from a llama_seq_id + llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} + + // needed for more convenient std::find + bool operator==(const llama_rs_seq_node & other) const { + return seq_id == other.seq_id; + } + + bool is_tail() const { + return next_cell < 0; + } +}; + +struct llama_rs_cell { + llama_pos pos = -1; + int32_t src = -1; // copy source id (cleared next when -1) + + // Link to previous cell in this sequence. + // Sequences can only diverge, never converge, + // so this works when there are multiple seq_ids per cell too. 
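+    // For example, two sequences sharing a prompt form a small tree:
+    //   [A] <- [B] <- [C]   (tail of seq 0)
+    //            ^--- [D]   (tail of seq 1)
+    // A and B hold both seq_ids; C and D both point back to B through `prev`,
+    // while B's per-seq_id next_cell links point forward to C and D respectively.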
+ int32_t prev = -1; + + // ref count of tails (should match the number of next_cell == -1 in seq_nodes) + uint32_t tail_rc = 0; + + // seq_ids by insertion order, to simplify updating n_cells compared to a set + std::vector seq_nodes; + + void insert_node(const llama_rs_seq_node & node) { + auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); + if (node_dest == seq_nodes.end()) { + seq_nodes.push_back(node); + } else { + // overwrite the pre-existing node with the same seq_id if it exists + *node_dest = node; + } + } + + bool has_seq_id(const llama_seq_id & id) const { + return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); + } + + bool is_empty() const { + return seq_nodes.empty(); + } +}; + + +struct llama_rs_seq_meta { + // cell id of the latest state of this seq_id + int32_t tail = -1; + // number of cells for which this seq_id is the first + // (useful to know if cells in this sequence should be pruned) + int32_t n_cells = 0; + // changing the tail cell of a sequence can only be done at batch boundary, + // this guards against changing the cell when it shouldn't be; + // should be cleared when done finding a slot + bool in_ubatch = false; +}; + +// ring-buffered tree of cached recurrent state data +struct llama_rs_cache { + + uint32_t head = 0; // first state used for the last slot + uint32_t size = 0; + uint32_t used = 0; + + // computed when finding a slot + uint32_t n = 0; // range of states used for the last slot + + // only counts cells which are tails of all of their sequences. + // useful to know the minimum reserved cell count per seq_id. + uint32_t n_seqs = 0; + // cells part of multiple sequences, + // but which are only the tail of some of them. + // useful to dismiss sequences used as a shared prompt + uint32_t n_shared_tail_cells = 0; + + // with state models, a cell can hold the state for more than one past token + // TODO: it's probably not possible to always use contiguous cells + std::vector cells; + + // find tail cells faster + std::vector seq_tails; // map seq_ids to cell ids + + // per layer + // NOTE: the naming of r and s is arbitrary + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states + + // TODO: maybe use a simpler data structure than a tree + + // Inefficient, but thorough verification and rebuilding of the rs cache + // from only the cells list with `pos` and seq_ids. + // Should not be called in a hot loop except when desperate and/or debugging. 
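+    // Returns whether the derived metadata (tails, counts, prev/next links) was
+    // already consistent with the cells list; inconsistent fields are overwritten
+    // with the recomputed values along the way.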
+ bool rebuild(bool debug) { + bool was_valid = true; + // skip for non-recurrent models + if (size == 0) { return true; } + // the source of truth is the cells list + // buffer sizes + if (size != cells.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", + __func__, cells.size(), size); + } + cells.resize(size); + was_valid = false; + } + if (size != seq_tails.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", + __func__, seq_tails.size(), size); + } + seq_tails.resize(size); + was_valid = false; + } + // cells consistency + uint32_t used_verif = 0; + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.seq_nodes.empty()) { + if (cell.pos >= 0) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } + cell.pos = -1; + was_valid = false; + } + } + if (cell.pos < 0) { + if (cell.pos != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } + cell.pos = -1; + was_valid = false; + } + if (!cell.seq_nodes.empty()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", + __func__, cell_id, cell.seq_nodes.size()); + } + cell.seq_nodes.clear(); + was_valid = false; + } + cell.src = -1; + if (cell.prev != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.prev); + } + cell.prev = -1; + was_valid = false; + } + } else if (!debug) { + // Assuming the cache should be actually rebuilt when not debugging + cell.src = cell_id; + } + if (!cell.seq_nodes.empty()) { + used_verif += 1; + } + } + if (used != used_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n", + __func__, used, used_verif); + } + used = used_verif; + was_valid = false; + } + // tail verification + std::vector> seq_cells; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + seq_cells.clear(); + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.has_seq_id(seq_id)) { + seq_cells.push_back({cell.pos, cell_id}); + } + } + // sort by pos and then by cell_id + std::sort(seq_cells.begin(), seq_cells.end()); + int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second; + if (tail != seq.tail) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.tail, tail); + } + seq.tail = tail; + was_valid = false; + } + int32_t prev = -1; + for (size_t i = 0; i < seq_cells.size(); ++i) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + if (cell.prev != prev) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, cell.prev, prev); + } + cell.prev = prev; + was_valid = false; + } + prev = cell_id; + } + int32_t n_cells = 0; + int32_t next = -1; + for (size_t i = seq_cells.size(); i-- > 0;) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + // assuming it's always found, because how else would it end up in the list of cells for this seq_id? 
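+                // only the first seq_node of a cell counts towards that seq_id's
+                // n_cells (see the definition in llama_rs_seq_meta)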
+ auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); + if (seq_node == cell.seq_nodes.begin()) { + n_cells += 1; + } + if (seq_node->next_cell != next) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, seq_node->next_cell, next); + } + seq_node->next_cell = next; + was_valid = false; + } + next = cell_id; + } + if (seq.n_cells != n_cells) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.n_cells, n_cells); + } + seq.n_cells = n_cells; + } + // in_batch should only be true when in the process of finding a slot + if (seq.in_ubatch != false) { + if (debug) { + LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", + __func__, seq_id); + } + seq.in_ubatch = false; + was_valid = false; + } + } + // tail_rc + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + uint32_t tail_rc = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { + tail_rc += 1; + } + } + if (cell.tail_rc != tail_rc) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", + __func__, cell_id, cell.tail_rc, tail_rc); + } + cell.tail_rc = tail_rc; + was_valid = false; + } + } + // n_seqs + uint32_t n_seqs_verif = 0; + uint32_t n_shared_tail_cells_verif = 0; + for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { + llama_rs_cell & rs_cell = cells[cell_id]; + if (!rs_cell.seq_nodes.empty()) { + if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + n_seqs_verif += 1; + } else if (rs_cell.tail_rc > 0) { + n_shared_tail_cells_verif += 1; + } + } + } + if (n_seqs != n_seqs_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", + __func__, n_seqs, n_seqs_verif); + } + n_seqs = n_seqs_verif; + was_valid = false; + } + if (n_shared_tail_cells != n_shared_tail_cells_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", + __func__, n_shared_tail_cells, n_shared_tail_cells_verif); + } + n_shared_tail_cells = n_shared_tail_cells_verif; + was_valid = false; + } + return was_valid; + } + + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + // Why an iterator? Because it allows using std::vector::erase. 
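+    // A minimal usage sketch (this is what clear_cell() does below):
+    //   for (auto it = rs_cell.seq_nodes.begin(); it != rs_cell.seq_nodes.end();) {
+    //       it = remove_seq_node_from_cell(rs_cell, it);
+    //   }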
+ std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + // The iterator needs to point inside the correct vector + GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); + if (node_iter != rs_cell.seq_nodes.end()) { + // update the tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail + cells[node.next_cell].prev = rs_cell.prev; + } + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + // move the tail back to the previous cell + if (prev_cell.seq_nodes.size() > 1) { + if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf with a single tail + if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + n_seqs -= 1; + } + } + } + prev_cell.tail_rc += 1; + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (rs_cell.tail_rc == 1) { + if (seq.tail < 0) { + // no more tail, no more sequence + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; + } + } + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { + // will fully become a tail cell + if (rs_cell.tail_rc > 0) { + n_seqs += 1; + } + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + + auto next_node = std::next(node_iter); + if (next_node != rs_cell.seq_nodes.end()) { + // the next node is the new first one, so update its n_cells + if ((uint32_t) next_node->seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node->seq_id]; + next_seq.n_cells += 1; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else { + // this was the last seq_id of the cell + used -= 1; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + // the other fields *should* have already been updated elsewhere + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + return rs_cell.seq_nodes.erase(node_iter); + } + return node_iter; + } + + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + node_iter = remove_seq_node_from_cell(rs_cell, node_iter); + } + } + + // returns whether or not the seq_id was removed + bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); + } + return false; + } + + bool 
insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < seq_tails.size()) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto & seq = seq_tails[id]; + int32_t prev = rs_cell.prev; + if ((uint32_t) seq.tail == i_cell) { + // the cell is already the tail of this seq_id + return false; + } + if (rs_cell.is_empty()) { + prev = seq.tail; + } + // ensure the new tail won't mess up the tree + GGML_ASSERT(seq.tail == -1 || seq.tail == prev); + if (prev >= 0 && (uint32_t) prev < size) { + // the targeted cell has a previous cell + llama_rs_cell & prev_cell = cells[prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing + GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken + if (rs_cell.pos < 0) { + GGML_ASSERT(rs_cell.is_empty()); + rs_cell.pos = prev_cell.pos + 1; + rs_cell.src = prev_cell.src; + } + prev_node->next_cell = i_cell; + rs_cell.prev = prev; + if (seq.tail == prev) { + // What to do when the tail moves... + // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) + // O -> oO (n_seqs--, n_shared_tail_cells++) + // O -> O (seq.n_cells++) + // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) + // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) + // _ -> oO (n_shared_tail_cells++) + // _ -> O (seq.n_cells++, n_seqs++) + // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) + // Oo -> OO+ (n_shared_tail_cell--) + // OOo -> O (seq.n_cells++, n_seqs++) + if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { + // from fully tail + if (prev_cell.tail_rc > 1) { + // the previous tail becomes shared with a non-tail + n_shared_tail_cells += 1; + } + if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { + // the new tail cell was previously a fully non-tail cell + n_shared_tail_cells += 1; + n_seqs -= 1; + } + } else if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + if (prev_cell.tail_rc == 1) { + // it was the last tail of the previous cell + n_shared_tail_cells -= 1; + } + } + } + prev_cell.tail_rc -= 1; + } + if (rs_cell.is_empty()) { + // to unique + seq.n_cells += 1; + if (seq.tail < 0) { + // from empty to unique + n_seqs += 1; + // pos was not yet set + rs_cell.pos = 0; + rs_cell.src = -1; + } + used += 1; + } else if (rs_cell.tail_rc == 0) { + // to shared + if (seq.tail < 0) { + // from empty to shared + n_shared_tail_cells += 1; + } + } + // the target cell was not already a tail of this seq_id + rs_cell.insert_node(id); // next_cell == -1 by default + rs_cell.tail_rc += 1; + seq.tail = i_cell; + return true; + } + return false; + } + + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { + uint32_t seqs = n_seqs; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; + } + return (size - n_shared_tail_cells) / (seqs > 0 ? 
seqs : 1); + } + + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * r : r_l) { + size += ggml_nrows(r) * ggml_row_size(r->type, r->ne[0]); + } + for (struct ggml_tensor * s : s_l) { + size += ggml_nrows(s) * ggml_row_size(s->type, s->ne[0]); + } + return size; + } +}; + +struct llama_cache { + // key + value cache for self attention + llama_kv_cache kv; + + // recurrent state cache for state space models + llama_rs_cache rs; + std::vector ctxs; std::vector bufs; + // NOTE: padding may make this bigger than kv.total_size() + rs.total_size() size_t total_size() const { size_t size = 0; for (ggml_backend_buffer_t buf : bufs) { @@ -2034,7 +2636,7 @@ struct llama_kv_cache { return size; } - ~llama_kv_cache() { + ~llama_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -2227,8 +2829,8 @@ struct llama_context { const llama_model & model; - // key + value cache for the self attention - struct llama_kv_cache kv_self; + // key + value cache for self-attention, and/or recurrent state cache + struct llama_cache cache; std::mt19937 rng; @@ -2285,9 +2887,9 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_copy; // I32 [n_rs] + struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] + struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -2392,54 +2994,44 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { } // -// kv cache helpers +// kv and rs cache helpers // -static bool llama_kv_cache_init( - struct llama_kv_cache & cache, +static bool llama_cache_init( + struct llama_cache & cache, const llama_context * ctx, ggml_type type_k, ggml_type type_v, - uint32_t kv_size, bool offload) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn; + const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0; + const bool has_s = hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0; + const bool has_rs = has_r || has_s; + const uint32_t kv_size = has_kv ? cparams.n_ctx : 0; + const uint32_t rs_size = has_rs ? 
cparams.n_seq_max : 0; const int64_t n_layer = hparams.n_layer; - cache.has_shift = false; + cache.kv.size = kv_size; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; - cache.v_trans = !cparams.flash_attn; + cache.kv.v_trans = !cparams.flash_attn; - // TODO: support mixed recurrent Transformer architectures - // NOTE: (!a || b) is a logical implication (a -> b) - GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.kv.type_k = type_k; + cache.kv.type_v = type_v; - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + cache.kv.cells.clear(); + cache.kv.cells.resize(kv_size); - cache.type_k = type_k; - cache.type_v = type_v; + cache.rs.size = rs_size; - cache.cells.clear(); - cache.cells.resize(kv_size); - - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } + cache.rs.cells.clear(); + cache.rs.cells.resize(rs_size); + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -2459,8 +3051,9 @@ static bool llama_kv_cache_init( std::map ctx_map; for (auto & it : buft_layer_count) { int n_layers = it.second; + // TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -2473,17 +3066,37 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (has_kv) { + cache.kv.k_l.reserve(n_layer); + cache.kv.v_l.reserve(n_layer); + } + if (has_r) { + cache.rs.r_l.reserve(n_layer); + } + if (has_s) { + cache.rs.s_l.reserve(n_layer); + } for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? 
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + if (has_kv) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.kv.k_l.push_back(k); + cache.kv.v_l.push_back(v); + } + if (has_r) { + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); + ggml_format_name(r, "cache_r_l%d", i); + cache.rs.r_l.push_back(r); + } + if (has_s) { + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); + ggml_format_name(s, "cache_s_l%d", i); + cache.rs.s_l.push_back(s); + } } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -2492,11 +3105,15 @@ static bool llama_kv_cache_init( ggml_context * ctx = it.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { + if (!has_kv && !has_rs) { + // no buffer was needed, so this is fine + return true; + } LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -2507,107 +3124,255 @@ static bool llama_kv_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { +static bool llama_cache_find_slot( + struct llama_cache & cache, + const struct llama_batch & batch) { + const uint32_t kv_size = cache.kv.size; + const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - if (cache.recurrent) { - // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. - - llama_seq_id min = cache.size - 1; - llama_seq_id max = 0; - + // only check first, to allow failing gracefully + if (rs_size > 0) { + // everything should fit if all seq_ids are smaller than the max for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { - if (seq_id > max) { - max = seq_id; - } - if (seq_id < min) { - min = seq_id; - } - // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. 
- LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); - } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; - } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set - } else { + + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { // too big seq_id - // TODO: would it be possible to resize the KV cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + // TODO: would it be possible to resize the rs cache size instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } } - - // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; - - // sanity check - return max >= min; - } - // otherwise, one cell per token. - - if (n_tokens > cache.size) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; } - uint32_t n_tested = 0; - - while (true) { - if (cache.head + n_tokens > cache.size) { - n_tested += cache.size - cache.head; - cache.head = 0; - continue; + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); + return false; } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (cache.kv.head > cache.kv.used + 2*n_tokens) { + cache.kv.head = 0; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.kv.head + n_tokens > kv_size) { + n_tested += kv_size - cache.kv.head; + cache.kv.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.kv.cells[cache.kv.head + i].pos >= 0) { + found = false; + cache.kv.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { break; } + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + } + + // now modification can be done, and should NOT fail + + if (rs_size > 0) { + // For recurrent state architectures (like Mamba), + // each cache cell can store the state for a whole sequence. 
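+        // Cells are assigned roughly as follows: a sequence keeps writing into its
+        // tail cell until a checkpoint is due or its set of seq_ids changes, at which
+        // point a new cell is reserved, pruning shared prompts or over-long sequences
+        // when no empty cell is available.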
+ // TODO: find a way to always make the rs slot contiguous + + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; + + for (uint32_t i = 0; i < n_tokens; ++i) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size + GGML_ASSERT((uint32_t) seq_id < rs_size); + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + // this seq_id has no tail (and is empty) + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } + } + + // reserve a cell for this seq_id + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + // avoid infinite loop + // NOTE: this should not fail; if it does, it's a bug. 
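The checkpointing condition used a few lines above (reserve a new cell when the sequence sets differ, when there is no previous checkpoint, or when the tail has advanced far enough) can be isolated as a small helper; this is only an illustrative restatement, with the interval hard-coded to 16 as in the patch's TODO:

    #include <cstdint>

    // decide between saving a new recurrent-state checkpoint and growing the
    // current tail cell in place (simplified: ignores the in_ubatch bookkeeping)
    static bool need_checkpoint(int32_t tail_pos, int32_t prev_pos, bool has_prev, bool same_seqs,
                                int32_t interval = 16) {
        if (!same_seqs) { return true; }        // the sequence sets diverge -> must fork
        if (!has_prev)  { return true; }        // nothing older to roll back to
        return tail_pos - prev_pos >= interval; // enough tokens since the last checkpoint
    }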
+ GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } + } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + + // Assuming the tokens are in-order + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. 
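The cell-search loop above effectively applies a fixed eviction priority; the following sketch only names those cases (the real code also re-links sequence nodes and bumps the reused tail's position):

    #include <cstddef>
    #include <cstdint>

    enum class rs_action { take_empty, reuse_own_tail, skip, prune_shared, prune_long_seq };

    // hypothetical classifier mirroring the order of the checks above
    static rs_action classify_candidate(bool is_empty, int32_t tail_rc, bool is_own_tail,
                                        size_t n_seqs_in_cell, uint32_t cells_of_owner,
                                        uint32_t min_cells_per_seq) {
        if (is_empty)                           { return rs_action::take_empty; }     // always preferred
        if (tail_rc == 1 && is_own_tail)        { return rs_action::reuse_own_tail; } // overwrite in place
        if (tail_rc > 0)                        { return rs_action::skip; }           // someone else's tail
        if (n_seqs_in_cell > 1)                 { return rs_action::prune_shared; }   // shared prompt cell
        if (cells_of_owner > min_cells_per_seq) { return rs_action::prune_long_seq; } // over its budget
        return rs_action::skip;
    }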
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); + } + } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; + } + + // allow getting the range of used cells, from head to head + n + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; + + // sanity check + GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + } + + if (kv_size > 0) { + for (uint32_t i = 0; i < n_tokens; i++) { + cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + } } - if (found) { - break; - } - - if (n_tested >= cache.size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } + cache.kv.used += n_tokens; } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; - - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); - } - } - - cache.used += n_tokens; - return true; } -// find how many cells are currently in use +// find how many KV cells are currently in use static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -2620,218 +3385,395 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - } - cache.head = 0; - cache.used = 0; +// find how many recurrent state cells are currently in use +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_rs_cell & cell = cache.cells[i - 1]; + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +static void llama_cache_clear(struct llama_cache & cache) { + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + kv_cell.pos = -1; + kv_cell.delta = 0; + kv_cell.seq_id.clear(); + } + cache.kv.has_shift = false; + cache.kv.do_defrag = false; + cache.kv.head = 0; + cache.kv.used = 0; + } + if (cache.rs.size > 0) { + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + } + cache.rs.head = 0; + cache.rs.used = 0; + cache.rs.n_seqs = 0; + cache.rs.n_shared_tail_cells = 0; + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(cache.rs.size); + } for (auto & buf : cache.bufs) { ggml_backend_buffer_clear(buf, 0); } } -static bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_rm( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } - // models like Mamba can't have a state partially erased 
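Both cell_max helpers above share the same backwards scan; a minimal standalone version (simplified cell type) is:

    #include <cstdint>
    #include <vector>

    struct sketch_cell { int32_t pos = -1; bool empty = true; };

    // return one past the index of the last cell still in use, or 0 if none are
    static uint32_t cell_max(const std::vector<sketch_cell> & cells) {
        for (uint32_t i = (uint32_t) cells.size(); i > 0; --i) {
            const sketch_cell & cell = cells[i - 1];
            if (cell.pos >= 0 && !cell.empty) { return i; }
        }
        return 0;
    }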
- if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { + if (seq_id >= (int64_t) cache.rs.size) { // could be fatal - return false; + return n_past; } - if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - } + uint32_t new_head = cache.rs.size; + // adjust p0 and p1 according to the states found + llama_pos new_p0 = 0; + llama_pos new_p1 = std::numeric_limits::max(); - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); + // partial seq_id removal has to happen from the tail + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + // copy before the cell is potentially changed + int32_t prev_id = rs_cell.prev; + if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { + // non-tail removal for shared cells can only be done when clearing a cell + // (i.e. when the next cell's link to the previous cell can be safely changed) + p1 = rs_cell.pos + 1; + } + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // if the node isn't found, the sequence tree is malformed + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + // get the smallest removed cell id + if (new_head > (uint32_t) cell_id) { new_head = cell_id; } } else { - continue; - } - if (cache.cells[i].is_empty()) { - // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + // one more than the biggest non-removed cell of this sequence + if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; + if (rs_cell.pos < p0) { + // new_p0 should be right after the max pos in the states before p0 + if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } + } else { // (rs_cell.pos >= p1) + // new_p1 should be the min pos in the states after p1 + if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } + } + } + cell_id = prev_id; + } + p0 = new_p0; + p1 = new_p1; + + // If we freed up a slot, set head to it so searching can start there. 
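Sequence-level operations on the recurrent cache all walk the same backwards chain: start at seq_tails[seq_id].tail and follow cell.prev until it goes negative. A minimal model of that traversal, with simplified types and a hypothetical visitor callback:

    #include <cstdint>
    #include <vector>

    struct sketch_rs_cell { int32_t pos = -1; int32_t prev = -1; };

    template <typename F>
    static void for_each_cell_of_seq(std::vector<sketch_rs_cell> & cells, int32_t tail, F && visit) {
        for (int32_t cell_id = tail; cell_id >= 0; ) {
            sketch_rs_cell & cell = cells[cell_id];
            const int32_t prev = cell.prev; // copy first: the visitor may clear the cell
            visit(cell_id, cell);
            cell_id = prev;
        }
    }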
+ if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } + } + + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + + if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + if (seq_id < 0) { + kv_cell.seq_id.clear(); + } else { // (kv_cell.has_seq_id(seq_id)) + kv_cell.seq_id.erase(seq_id); + } + if (kv_cell.is_empty()) { + // keep count of the number of used cells + if (kv_cell.pos >= 0) { cache.kv.used--; } + + kv_cell.pos = -1; + if (new_head == cache.kv.size) { new_head = i; } + } + } else if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; + } + } + + return n_past; +} + +static llama_pos llama_cache_seq_cp( + struct llama_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } + + // TODO: in practice this seems to be only used on whole sequences; + // should partial sequence copy support be removed? + // TODO: What if the destination sequence is not empty? + + llama_pos n_past = 0; + + if (cache.rs.size > 0) { + // have to start from the beginning for recurrent models + p0 = 0; + if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { + int32_t src_head = -1; + int32_t head_pos = p1; + int32_t src_next = -1; + // find the start of the sequence + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.is_empty() && rs_cell.prev < 0) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + if (seq_node != rs_cell.seq_nodes.end()) { + src_head = i; + head_pos = rs_cell.pos; + src_next = seq_node->next_cell; + break; + } + } + } + while (src_head >= 0 && head_pos < p1) { + cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); + src_head = src_next; + if (head_pos >= n_past) { n_past = head_pos + 1; } + if (src_next >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[src_next]; + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + head_pos = rs_cell.pos; + // it should always be found if the seq tree is valid + GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); + src_next = seq_node->next_cell; + } + } + } + p1 = n_past; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { + kv_cell.seq_id.insert(seq_id_dst); + if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } - // If we freed up a slot, set head to it so searching can start there. 
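llama_cache_seq_cp above walks the source sequence in the other direction: locate the first cell of the sequence (the one with prev < 0 that carries seq_id_src), then follow that sequence's next_cell links while inserting seq_id_dst as a new tail at each step. A simplified sketch of the lookup for the starting cell (stand-in types):

    #include <cstdint>
    #include <vector>

    struct sketch_seq_node { int32_t seq_id = -1; int32_t next_cell = -1; };
    struct sketch_rs_cell  { int32_t pos = -1; int32_t prev = -1; std::vector<sketch_seq_node> seq_nodes; };

    // return the index of the first cell of seq_id, or -1 when the sequence is empty
    static int32_t find_seq_start(const std::vector<sketch_rs_cell> & cells, int32_t seq_id) {
        for (int32_t i = 0; i < (int32_t) cells.size(); ++i) {
            if (cells[i].prev >= 0 || cells[i].seq_nodes.empty()) { continue; } // not a sequence start
            for (const auto & node : cells[i].seq_nodes) {
                if (node.seq_id == seq_id) { return i; } // the forward walk continues via node.next_cell
            }
        }
        return -1;
    }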
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head; - - return true; + return n_past; } -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); +static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.rs.size > 0) { + uint32_t new_head = cache.rs.size; - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; + // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (i == (uint32_t) seq_id) { continue; } + llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + cell_id = rs_cell.prev; + if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { + new_head = cell_id; + } + } + } - // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } + } + + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (!kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= 0) { cache.kv.used--; } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { new_head = i; } } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(seq_id); } - - cache.do_copy = true; - - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; } - return; - } - // otherwise, this is the KV cache of a Transformer-like model - cache.head = 0; - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; } } } -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; +static void llama_cache_seq_add( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { - for (uint32_t i = 0; i < cache.size; ++i) { - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); - } - } + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; -} - -static void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; - - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - - if (cache.recurrent) { + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - return; - } - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; - - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; - } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { - new_head = i; + auto & seq = cache.rs.seq_tails[seq_id]; + // follow the sequence from its tail + int32_t cell_id = seq.tail; + uint32_t new_head = cache.rs.size; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + cell_id = rs_cell.prev; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + if (rs_cell.pos < 0) { + // NOTE: this affects the other sequences which share the cell + cache.rs.clear_cell(rs_cell); + if (new_head > (uint32_t) cell_id) { + new_head = cell_id; + } } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.rs.head = new_head != cache.rs.size ? new_head : 0; } - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + kv_cell.pos += delta; + kv_cell.delta += delta; + + if (kv_cell.pos < 0) { + if (!kv_cell.is_empty()) { + cache.kv.used--; + } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { + new_head = i; + } + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
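For the non-recurrent cells, llama_cache_seq_add above is metadata-only: the shift is recorded now and applied to the K tensors later through inp_K_shift (see llama_set_k_shift and llama_kv_cache_update further down). A minimal model of that bookkeeping for a single cell:

    #include <cstdint>

    struct sketch_kv_cell { int32_t pos = -1; int32_t delta = 0; };

    static void shift_cell(sketch_kv_cell & cell, int32_t add, bool & has_shift) {
        cell.pos   += add; // logical position, used for masking from now on
        cell.delta += add; // consumed once by the deferred K-shift (RoPE) graph
        has_shift   = true;
    }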
+ // Otherwise we just start the next search from the beginning. + cache.kv.head = new_head != cache.kv.size ? new_head : 0; + } } -static void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); +static void llama_cache_seq_div( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } - if (cache.recurrent) { + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + auto & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; } + cell_id = rs_cell.prev; } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; - { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + { + llama_pos p_old = kv_cell.pos; + kv_cell.pos /= d; + kv_cell.delta += kv_cell.pos - p_old; + } + } } } } } -static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { - llama_pos result = 0; +static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { + llama_pos result = -1; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = std::max(result, cache.cells[i].pos); + if (cache.rs.size > 0) { + int32_t cell_id = cache.rs.seq_tails[seq_id].tail; + if (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + result = rs_cell.pos; + } + // exit early + return result; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + result = std::max(result, kv_cell.pos); + } } } @@ -3352,9 +4294,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -3362,16 +4304,17 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + // TODO: allow ANY lossless cast + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + 
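The get_arr change here boils down to: find the key, require an array of the expected element type, then copy it out; the per-layer n_head_kv vector loaded just below relies on it. A rough standalone equivalent for the int32 case (hypothetical helper; the exact gguf API spelling may differ between ggml versions):

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>
    #include "ggml.h" // gguf_* declarations

    static std::vector<int32_t> read_i32_array(const gguf_context * meta, const char * key) {
        const int kid = gguf_find_key(meta, key);
        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
            throw std::runtime_error(std::string("array key not found in model: ") + key);
        }
        if (gguf_get_arr_type(meta, kid) != GGUF_TYPE_INT32) {
            throw std::runtime_error(std::string(key) + " is not an int32 array");
        }
        const int32_t * data = (const int32_t *) gguf_get_arr_data(meta, kid);
        return std::vector<int32_t>(data, data + gguf_get_arr_n(meta, kid));
    }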
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); } - // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); - - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + result.reserve(arr_info.length); + result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } @@ -3916,7 +4859,12 @@ static void llm_load_hparams( // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + + // per-layer n_head_kv + if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) { + // global/fallback n_head_kv + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + } bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -4284,6 +5232,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5929,10 +6893,7 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -6050,6 +7011,118 @@ static bool llm_load_tensors( layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + GGML_ASSERT((int64_t) hparams.n_head_kv_vec.size() == n_layer); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output 
== NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv_vec[i]; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + + layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + + layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + + layer.ssm_dt_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}); + + layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + + layer.ssm_b_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}); + layer.ssm_c_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + + layer.wq = nullptr; + layer.wk = nullptr; + layer.wv = nullptr; + layer.wo = nullptr; + + } else { + // Attention layers + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ssm_in = nullptr; + layer.ssm_conv1d = nullptr; + layer.ssm_conv1d_b = nullptr; + layer.ssm_x = nullptr; + layer.ssm_dt_norm = nullptr; + layer.ssm_dt = nullptr; + layer.ssm_dt_b = nullptr; + layer.ssm_b_norm = nullptr; + layer.ssm_c_norm = nullptr; + layer.ssm_a = nullptr; + layer.ssm_d = nullptr; + layer.ssm_out = nullptr; + } + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + + layer.ffn_gate = nullptr; + layer.ffn_down = nullptr; + layer.ffn_up = nullptr; + } else { + // FFN (no MoE) + layer.ffn_gate = ml.create_tensor(ctx_split, 
tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.ffn_gate_exps = nullptr; + layer.ffn_down_exps = nullptr; + layer.ffn_up_exps = nullptr; + } + } + } break; case LLM_ARCH_XVERSE: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6498,8 +7571,8 @@ static void llm_build_kv_store( int64_t il) { const int64_t n_ctx = cparams.n_ctx; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(kv.size == n_ctx); @@ -6765,9 +7838,9 @@ static struct ggml_tensor * llm_build_kqv( int il) { const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head_kv = hparams.n_head_kv_l(il); const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -6903,6 +7976,132 @@ static struct ggml_tensor * llm_build_kv( return cur; } +// TODO: split +static struct ggml_tensor * llm_build_mamba( + struct ggml_context * ctx, + const llama_model & model, + const llama_hparams & hparams, + const llama_rs_cache & rs, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + struct ggml_tensor * state_seq, + struct ggml_tensor * w_dt_norm, + struct ggml_tensor * w_b_norm, + struct ggml_tensor * w_c_norm, + int32_t n_tokens, + int32_t rs_head, + int32_t n_rs, + const llm_build_cb & cb, + int il) { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs + conv_states = ggml_get_rows(ctx, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx, ssm_states, state_copy); + } + + // clear states of sequences which are starting at the beginning of this batch + { + conv_states = ggml_mul(ctx, conv_states, state_mask); + ssm_states = ggml_mul(ctx, ssm_states, state_mask); + } + + conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); + ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + + // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_tokens} + struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); + struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + + // conv + { + // Custom operator which 
is needed only to ease simultaneous sequence processing. + // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // The new conv_states is the last (d_conv - 1) columns + // of the last 3rd dimensional "layer" of the self-overlapping view. + // For simultaneous sequences, it's more complicated. + struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner * n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + + // extract x from x_conv + x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + + // bias + x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); + struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + + if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_c_norm, NULL, LLM_NORM_RMS, cb, il); } + + // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, + // because only a single tensor can be returned.
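For readers unfamiliar with ggml_ssm_scan, a naive per-token reference of the recurrence it fuses (single sequence only; the softplus on dt and the parallel-scan optimization live inside the real operator) could be written as:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // h[i][j] <- exp(dt[i] * A[i][j]) * h[i][j] + dt[i] * B[j] * x[i]
    // y[i]    <- sum_j C[j] * h[i][j]
    // shapes as in the comments above: x, dt, y are {d_inner}; B, C are {d_state};
    // A and h hold {d_state} values per inner channel; one token of one sequence
    static std::vector<float> ssm_scan_ref(std::vector<std::vector<float>> & h,
                                           const std::vector<std::vector<float>> & A,
                                           const std::vector<float> & x,
                                           const std::vector<float> & dt,
                                           const std::vector<float> & B,
                                           const std::vector<float> & C) {
        const size_t d_inner = x.size();
        const size_t d_state = B.size();
        std::vector<float> y(d_inner, 0.0f);
        for (size_t i = 0; i < d_inner; ++i) {
            for (size_t j = 0; j < d_state; ++j) {
                h[i][j] = std::exp(dt[i] * A[i][j]) * h[i][j] + dt[i] * B[j] * x[i];
                y[i]   += C[j] * h[i][j];
            }
        }
        return y; // the D*x skip term and the SiLU(z) gate are applied by the caller
    }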
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + + // store last states (the second part of y_ssm_states) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + + struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + + // TODO: skip computing output for unused tokens + + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx, y, ggml_silu(ctx, z)); + + // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); + } + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -6910,6 +8109,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_batch & batch; const llama_kv_cache & kv_self; + const llama_rs_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -6918,9 +8118,7 @@ struct llm_build_context { const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; - const int64_t n_embd_k_gqa; const int64_t n_embd_head_v; - const int64_t n_embd_v_gqa; const int64_t n_expert; const int64_t n_expert_used; @@ -6935,8 +8133,10 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_rs; const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_head; const int32_t n_orig_ctx; const bool flash_attn; @@ -6961,7 +8161,8 @@ struct llm_build_context { hparams (model.hparams), cparams (lctx.cparams), batch (batch), - kv_self (lctx.kv_self), + kv_self (lctx.cache.kv), + rs_self (lctx.cache.rs), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -6969,9 +8170,7 @@ struct llm_build_context { n_head (hparams.n_head), n_head_kv (hparams.n_head_kv), n_embd_head_k (hparams.n_embd_head_k), - n_embd_k_gqa (hparams.n_embd_k_gqa()), n_embd_head_v (hparams.n_embd_head_v), - n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), @@ -6984,8 +8183,10 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_outputs (worst_case ? n_tokens : lctx.n_outputs), - kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + n_rs (worst_case ? rs_self.size : rs_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), + kv_head (worst_case ? kv_self.size - n_tokens : kv_self.head), + rs_head (worst_case ? 
0 : rs_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -7040,9 +8241,9 @@ struct llm_build_context { // we rotate only the first n_rot dimensions ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_head_kv, n_ctx, + n_embd_head_k, hparams.n_head_kv_l(il), n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)), 0), lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -7054,29 +8255,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -7094,6 +8272,9 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, nm, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), @@ -7193,21 +8374,21 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; } struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_rs); cb(lctx.inp_s_mask, "inp_s_mask", -1); ggml_set_input(lctx.inp_s_mask); return lctx.inp_s_mask; } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -10313,12 +11494,62 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - const int64_t d_model = n_embd; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - GGML_ASSERT(2 * d_model == d_inner); - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + struct 
ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * state_seq = build_inp_s_seq(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, NULL, NULL, NULL, + n_tokens, rs_head, n_rs, cb, il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_jamba() { + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -10326,116 +11557,92 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); struct ggml_tensor * state_seq = build_inp_s_seq(); + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + const int64_t n_head_kv = hparams.n_head_kv_l(il); - // clear states of sequences which are starting at the beginning of this batch - { - conv_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), - state_mask); - ssm_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), - state_mask); - } - - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); - - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); - // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + if (n_head_kv == 0) { + // Mamba + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, + model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, + n_tokens, rs_head, n_rs, cb, il); + } else { + // 
Attention - // conv - { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - // extract x from x_conv - x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); - x = ggml_silu(ctx0, x); + // No RoPE :) + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } - // ssm - { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); - // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, - // because only a single tensor can be returned. 
- struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); - - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); - - struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - x = ggml_get_rows(ctx0, x, inp_out_ids); - y = ggml_get_rows(ctx0, y, inp_out_ids); - z = ggml_get_rows(ctx0, z, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); - - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } // residual - cur = ggml_add(ctx0, cur, inpL); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + cb, il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); cb(cur, "l_out", il); // input for next layer @@ -11041,23 +12248,6 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; - - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - - struct llm_build_context llm(lctx, dummy, cb, false); - - llm.init(); - - struct ggml_cgraph * result = llm.build_s_copy(); - - llm.free(); - - return result; -} - static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch, @@ -11199,6 +12389,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_mamba(); } break; + case LLM_ARCH_JAMBA: + { + result = llm.build_jamba(); + } break; case LLM_ARCH_XVERSE: { result = llm.build_xverse(); @@ -11233,26 +12427,14 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + const int64_t kv_size = lctx.cache.kv.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; for (int i = 0; i < kv_size; 
++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} - -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; + data[i] = lctx.cache.kv.cells[i].delta; } } @@ -11263,7 +12445,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; + const auto & kv_self = lctx.cache.kv; + const auto & rs_self = lctx.cache.rs; if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -11339,11 +12522,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { f = -INFINITY; } else { if (hparams.use_alibi) { - f = -fabs(lctx.kv_self.cells[i].pos - pos); + f = -fabs(kv_self.cells[i].pos - pos); } else { f = 0.0f; } @@ -11448,29 +12631,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + if (rs_self.size > 0) { + const int64_t n_rs = rs_self.n; if (lctx.inp_s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched - for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // clear unused states + for (int i = 0; i < n_rs; ++i) { + uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) rs_cell.src >= 0; - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (rs_cell.src < 0) { + rs_cell.src = cell_id; } } } + + // checkpoints require copies between cells + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + const uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; + + // prevent out-of-bound sources + if (rs_cell.src < 0 || (uint32_t) rs_cell.src >= rs_self.size) { + rs_cell.src = cell_id; + } + + data[i] = rs_cell.src; + + // ensure copy only happens once + if (rs_cell.src != (int32_t) cell_id) { + rs_cell.src = cell_id; + } + } + } + // For Mamba (and other recurrent architectures), // update the correct state(s)/sequence(s) for each token of the batch. + // Each row contains relative cell ids of the sequences for the associated token. // Like with the KQ_mask, if a token in the batch has multiple sequences, // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). 
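On the graph side (see llm_build_mamba above) these two inputs drive a row gather followed by a row-wise multiply; a plain-C++ model of that combination, with one state vector per cell and hypothetical sizes:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // states:  [n_cells][state_dim]   per-cell recurrent state
    // s_copy:  [n_rs]                 source cell for each working slot (ggml_get_rows)
    // s_mask:  [n_rs]                 1 = keep the copied state, 0 = start from zeros (ggml_mul)
    static std::vector<std::vector<float>> gather_and_mask(
            const std::vector<std::vector<float>> & states,
            const std::vector<int32_t>            & s_copy,
            const std::vector<float>              & s_mask) {
        std::vector<std::vector<float>> out(s_copy.size());
        for (size_t i = 0; i < s_copy.size(); ++i) {
            out[i] = states[s_copy[i]];                  // gather the source cell's state
            for (float & v : out[i]) { v *= s_mask[i]; } // clear states that start fresh
        }
        return out;
    }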
if (lctx.inp_s_seq) { @@ -11479,18 +12687,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); int32_t * data = (int32_t *) lctx.inp_s_seq->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - for (int i = 0; i < n_kv; ++i) { - if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; - } else { - data[j*n_kv + i] = -1; - } - } + data[i] = seq.tail - rs_self.head; } } } @@ -11619,7 +12824,8 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; + auto & rs_self = lctx.cache.rs; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -11735,17 +12941,11 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_cache_find_slot(lctx.cache, u_batch)) { return 1; } - if (!kv_self.recurrent) { + if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important @@ -11820,11 +13020,15 @@ static int llama_decode_internal( // update the kv ring buffer { kv_self.head += n_tokens; + rs_self.head += rs_self.n; // Ensure kv cache head points to a valid index. 
if (kv_self.head >= kv_self.size) { kv_self.head = 0; } + if (rs_self.head >= rs_self.size) { + rs_self.head = 0; + } } #ifdef GGML_PERF @@ -11898,6 +13102,10 @@ static int llama_decode_internal( } } n_outputs_prev += lctx.n_outputs; + +#ifndef NDEBUG + GGML_ASSERT(lctx.cache.rs.rebuild(true)); +#endif } // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -11928,7 +13136,7 @@ static int llama_decode_internal( // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; const auto & hparams = lctx.model.hparams; @@ -12151,7 +13359,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.cache.kv.has_shift) { { ggml_backend_sched_reset(lctx.sched); @@ -12167,7 +13375,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; kv_self.has_shift = false; @@ -12177,39 +13385,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - { - ggml_backend_sched_reset(lctx.sched); - - ggml_cgraph * gf = llama_build_graph_s_copy(lctx); - - ggml_backend_sched_alloc_graph(lctx.sched, gf); - - llama_set_s_copy(lctx); - - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); - - need_reserve = true; - } - - { - auto & kv_self = lctx.kv_self; - - kv_self.do_copy = false; - - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].src = i; - } - } - } - // defragment the KV cache if needed - if (lctx.kv_self.do_defrag) { + if (lctx.cache.kv.do_defrag) { llama_kv_cache_defrag_internal(lctx); need_reserve = true; - lctx.kv_self.do_defrag = false; + lctx.cache.kv.do_defrag = false; } // reserve a worst case graph again @@ -15948,18 +17130,8 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - // Mamba only needs a constant number of KV cache cells per sequence - if (model->arch == LLM_ARCH_MAMBA) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); @@ -16077,25 +17249,42 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } - { + if 
(ctx->cache.rs.size > 0) { + size_t memory_size_r = 0; + size_t memory_size_s = 0; + + for (auto & r : ctx->cache.rs.r_l) { + memory_size_r += ggml_nbytes(r); + } + + for (auto & s : ctx->cache.rs.s_l) { + memory_size_s += ggml_nbytes(s); + } + + LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); + } + if (ctx->cache.kv.size > 0) { size_t memory_size_k = 0; size_t memory_size_v = 0; - for (auto & k : ctx->kv_self.k_l) { + for (auto & k : ctx->cache.kv.k_l) { memory_size_k += ggml_nbytes(k); } - for (auto & v : ctx->kv_self.v_l) { + for (auto & v : ctx->cache.kv.v_l) { memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -16203,7 +17392,11 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) { } uint32_t llama_n_seq_max(const struct llama_context * ctx) { - return ctx->kv_self.size; + if (ctx->cache.rs.size > 0) { + return ctx->cache.rs.size; + } else { + return ctx->cache.kv.size; + } } enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { @@ -16219,6 +17412,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: return LLAMA_ROPE_TYPE_NONE; @@ -16501,8 +17695,9 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { - if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { - view->n_cells = int32_t(ctx->kv_self.size); + const llama_kv_cache & kv_self = ctx->cache.kv; + if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) { + view->n_cells = int32_t(kv_self.size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); view->cells = (struct llama_kv_cache_view_cell *)p; @@ -16511,7 +17706,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->cells_sequences = (llama_seq_id *)p; } - const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells; + const std::vector<llama_kv_cell> & kv_cells = kv_self.cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -16520,7 +17715,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -16558,67 +17753,118 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
view->max_contiguous_idx = max_contig_idx; view->token_count = token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != ctx->kv_self.used) { + if (uint32_t(used_cells) != kv_self.used) { LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, ctx->kv_self.used, used_cells); + __func__, kv_self.used, used_cells); } } +bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug) { + return ctx->cache.rs.rebuild(debug); +} + int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; - for (uint32_t i = 0; i < ctx->kv_self.size; i++) { - result += ctx->kv_self.cells[i].seq_id.size(); + for (uint32_t i = 0; i < ctx->cache.kv.size; i++) { + result += ctx->cache.kv.cells[i].seq_id.size(); } return result; } int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { - return ctx->kv_self.used; + return ctx->cache.kv.used; } +int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { + return ctx->cache.rs.used; +} + +void llama_cache_clear(struct llama_context * ctx) { + llama_cache_clear(ctx->cache); +} + +// deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_kv_cache_clear(ctx->kv_self); + llama_cache_clear(ctx); } +llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); +} + +// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); + llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); + return n_past >= p0; } -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + +llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + uint32_t n_seq_max = llama_n_seq_max(ctx); + if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { + return 0; + } if (seq_id_src == seq_id_dst) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); + return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } +// deprecated +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + llama_cache_seq_keep(ctx->cache, seq_id); +} + +// deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_seq_keep(ctx->kv_self, seq_id); + llama_cache_seq_keep(ctx, seq_id); } +void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (delta == 0) { return; } + + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); +} + +// deprecated void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - 
return; - } - - llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx, seq_id, p0, p1, delta); } +void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } + + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); +} + +// deprecated void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); + llama_cache_seq_div(ctx, seq_id, p0, p1, d); } +llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } + return llama_cache_seq_pos_max(ctx->cache, seq_id); +} + +// deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); + llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + return max_pos < 0 ? 0 : max_pos; } void llama_kv_cache_defrag(struct llama_context * ctx) { - llama_kv_cache_defrag(ctx->kv_self); + llama_kv_cache_defrag(ctx->cache.kv); } void llama_kv_cache_update(struct llama_context * ctx) { @@ -16671,9 +17917,10 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); const size_t s_v_trans = sizeof(uint32_t); - const size_t s_kv = ctx->kv_self.total_size(); + const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); - const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; + const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; + // FIXME: rs cache cells const size_t s_total = ( + s_rng_size @@ -16827,14 +18074,15 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data } } + // FIXME: copy rs cache // copy kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); @@ -16860,9 +18108,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent || !kv_self.v_trans) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v + if (!kv_self.v_trans) { const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); tmp_buf.resize(v_size); @@ -16980,14 +18226,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { } } + // FIXME: set rs cache // set kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = 
hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; uint32_t kv_head; @@ -17024,9 +18271,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent || !kv_self.v_trans) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v + if (!kv_self.v_trans) { const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size); @@ -17046,8 +18291,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - ctx->kv_self.head = kv_head; - ctx->kv_self.used = kv_used; + ctx->cache.kv.head = kv_head; + ctx->cache.kv.used = kv_used; for (uint32_t i = 0; i < kv_head; ++i) { llama_pos pos; @@ -17056,13 +18301,13 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); - ctx->kv_self.cells[i].pos = pos; + ctx->cache.kv.cells[i].pos = pos; llama_seq_id seq_id; for (size_t j = 0; j < seq_id_size; ++j) { memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); - ctx->kv_self.cells[i].seq_id.insert(seq_id); + ctx->cache.kv.cells[i].seq_id.insert(seq_id); } } } @@ -17177,12 +18422,12 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) size_t s_cell_count = 0; size_t s_cell_data_size = 0; - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < kv_self.size; ++i) { const auto & cell = kv_self.cells[i]; @@ -17221,8 +18466,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { llama_synchronize(ctx); - const auto & kv_self = ctx->kv_self; - GGML_ASSERT(!kv_self.recurrent); // not implemented + const auto & kv_self = ctx->cache.kv; + GGML_ASSERT(ctx->cache.rs.size == 0); // not implemented // Save the size of size_t as a uint32_t for safety check const uint32_t size_t_size = sizeof(size_t); @@ -17267,8 +18512,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // Write the layer count data_ctx.write(&n_layer, sizeof(n_layer)); @@ -17361,11 +18606,12 @@ size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* 
dst, llama_s size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { llama_synchronize(ctx); - auto & kv_self = ctx->kv_self; - GGML_ASSERT(!kv_self.recurrent); // not implemented + auto & cache = ctx->cache; + auto & kv_self = cache.kv; + GGML_ASSERT(cache.rs.size == 0); // not implemented // Wipe the slot - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); const uint8_t * inp = src; @@ -17396,8 +18642,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Sanity check model compatibility const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); if (n_layer != n_layer_ref) { LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); return 0; @@ -17420,7 +18666,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(kv_self, batch)) { + if (!llama_cache_find_slot(cache, batch)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; @@ -17449,7 +18695,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_type_i_ref); const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; if (k_type_i != k_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return 0; } @@ -17460,7 +18706,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_size_row_ref); const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); return 0; } @@ -17481,7 +18727,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -17492,7 +18738,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_row_ref); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); return 0; } @@ -17512,7 +18758,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = 
(int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -17523,7 +18769,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_el_ref); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_cache_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); return 0; } @@ -17705,11 +18951,19 @@ void llama_batch_free(struct llama_batch batch) { int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + return ret; } diff --git a/llama.h b/llama.h index 16cece5db..c5918057f 100644 --- a/llama.h +++ b/llama.h @@ -546,6 +546,12 @@ extern "C" { // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + // Rebuild and check the validity of the recurrent state cache's tree of sequences. + // (slow, use only for debugging purposes) + // Returns whether or not the rs cache was valid. + // The errors are always corrected, but only logged when debug is true. + LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug); + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -553,36 +559,62 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_cache_clear( + // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them) + LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed + LLAMA_API void llama_cache_clear( struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_cache_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + // Returns n_past (one more than the largest remaining pos in the seq_id) + // which is only meaningful to handle for partial removals. 
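A usage sketch for the return value documented here (illustrative only: rewind_seq is a hypothetical helper, not part of this change; the actual declaration follows below):

#include "llama.h"

// Truncate a sequence so that positions >= keep are gone. With a recurrent state
// cache the removal may only be able to rewind to an earlier checkpoint, so the
// caller must re-decode the positions in [n_past, keep) afterwards.
static llama_pos rewind_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos keep) {
    const llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, keep, -1);

    // n_past == keep : exact truncation, nothing to re-evaluate
    // n_past <  keep : the caller re-decodes tokens at positions [n_past, keep)
    return n_past;
}

For reference, the deprecated llama_kv_cache_seq_rm wrapper earlier in this patch derives its old boolean result from the same value, as n_past >= p0.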
+ LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + // Returns n_past (one more than the largest remaining pos in the destination seq_id) + // which is only meaningful to handle when partially copying. + LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -590,12 +622,19 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + LLAMA_API void llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "use llama_cache_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -603,17 +642,28 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + LLAMA_API void llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_cache_seq_div instead"); - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the largest position present in the KV and/or RS cache for the specified sequence + LLAMA_API llama_pos llama_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: diff --git a/tests/test-backend-ops.cpp 
b/tests/test-backend-ops.cpp index de74585da..f4c194591 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case { } }; +// GGML_OP_SSM_CONV +struct test_ssm_conv : public test_case { + const ggml_type type_s; + const ggml_type type_x; + const ggml_type type_c; + const ggml_type type_sq; + const int64_t d_inner; + const int64_t d_conv; + const int64_t n_tokens; + const int64_t n_rs; + + std::string vars() override { + return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs); + } + + test_ssm_conv(ggml_type type_s = GGML_TYPE_F32, + ggml_type type_x = GGML_TYPE_F32, + ggml_type type_c = GGML_TYPE_F32, + ggml_type type_sq = GGML_TYPE_I32, + int64_t d_inner = 10, + int64_t d_conv = 10, + int64_t n_tokens = 10, + int64_t n_rs = 10) + : type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs); + ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens); + ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner); + ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens); + ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + // pos + std::vector<int> data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = rand() % n_rs; + } + ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int)); + } else { + init_tensor_uniform(t); + } + } + } +}; + enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_ssm_conv()); + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 test_cases.emplace_back(new test_llama(1));
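The shapes exercised by test_ssm_conv correspond to a depthwise causal convolution over each of the d_inner channels, with a rolling window of the last d_conv inputs per channel. A minimal reference sketch of that computation (illustrative only: ssm_conv_ref is an invented name, it handles a single sequence, and it omits the state write-back and the sq-based cell selection that the real ggml_ssm_conv performs):

#include <vector>

// y[ch][t] = sum_k c[ch][k] * window[ch][t + k], where window is the previous
// d_conv-1 inputs of the channel followed by the new tokens (oldest first).
static std::vector<std::vector<float>> ssm_conv_ref(
        const std::vector<std::vector<float>> & state, // [d_inner][d_conv-1], oldest first
        const std::vector<std::vector<float>> & x,     // [d_inner][n_tokens]
        const std::vector<std::vector<float>> & c,     // [d_inner][d_conv]
        int d_conv, int n_tokens) {
    const int d_inner = (int) x.size();
    std::vector<std::vector<float>> y(d_inner, std::vector<float>(n_tokens, 0.0f));

    for (int ch = 0; ch < d_inner; ++ch) {
        std::vector<float> win = state[ch];
        win.insert(win.end(), x[ch].begin(), x[ch].end());

        for (int t = 0; t < n_tokens; ++t) {
            float acc = 0.0f;
            for (int k = 0; k < d_conv; ++k) {
                acc += c[ch][k] * win[t + k]; // last d_conv inputs ending at token t
            }
            y[ch][t] = acc;
        }
    }
    return y;
}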