Compare commits

...
Sign in to create a new pull request.

17 commits

Author SHA1 Message Date
Georgi Gerganov
ddc59e8e0a
wipwipwiwpip 2024-05-27 12:04:09 +03:00
Francis Couture-Harpin
fc59407efe convert-hf : support Mini-Jamba conversion 2024-05-25 13:56:21 -04:00
Francis Couture-Harpin
ea2e63e9d2 convert-hf : check for unprocessed Jamba experts 2024-05-25 12:54:30 -04:00
Francis Couture-Harpin
61a88a1da3 llama : fix BERT inference without KV cache 2024-05-24 22:41:38 -04:00
Francis Couture-Harpin
0fd13e9473 Merge branch 'master' into compilade/refactor-kv-cache 2024-05-24 19:35:16 -04:00
Francis Couture-Harpin
cbc743e600 llama : support Jamba 2024-05-24 19:27:27 -04:00
Francis Couture-Harpin
7e13f19fb5 llama : rethink recurrent state cell counts
* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot
2024-05-24 16:19:25 -04:00
Francis Couture-Harpin
3b57b55c6f Merge branch 'master' into compilade/refactor-kv-cache 2024-05-22 15:34:24 -04:00
Francis Couture-Harpin
b7ec12ebf7 Merge branch 'master' into compilade/refactor-kv-cache 2024-05-12 17:13:31 -04:00
Francis Couture-Harpin
b6fafd1747 llama : remove useless return value for some llama_cache_* functions 2024-04-29 12:59:43 -04:00
Francis Couture-Harpin
c460ff1a1c Merge branch 'master' into compilade/refactor-kv-cache 2024-04-29 10:31:39 -04:00
Francis Couture-Harpin
a09db95eab llama : rename many llama_kv_cache_* functions 2024-04-29 10:24:45 -04:00
Francis Couture-Harpin
d66849f628 Merge branch 'master' into compilade/refactor-kv-cache 2024-04-09 20:33:38 -04:00
Francis Couture-Harpin
0c8b3b2095 llama : correctly handle more edge cases for the rs cache 2024-04-09 17:35:52 -04:00
Francis Couture-Harpin
0028010d01 llama : state checkpoints for recurrent models 2024-04-08 09:54:35 -04:00
Francis Couture-Harpin
8db1e4d45f llama : use std::find for seq_nodes in llama_rs_cache 2024-04-04 10:46:43 -04:00
Francis Couture-Harpin
271104c65c wip: llama : separate recurrent states from the KV cache
This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
2024-04-03 20:47:34 -04:00
10 changed files with 2315 additions and 705 deletions

View file

@ -2338,7 +2338,7 @@ class MambaModel(Model):
self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_state_size(d_state)
@ -2384,6 +2384,135 @@ class MambaModel(Model):
) )
@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused
return "gpt-2"
def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Mini-Jamba
name = name.replace(".moe.", ".feed_forward.")
if bid is not None:
moe_offset = self.hparams["expert_layer_offset"]
moe_period = self.hparams["expert_layer_period"]
if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
name = name.replace(".experts.0.", ".")
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
new_name = self.map_tensor_name(merged_name)
yield new_name, data_torch
return
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield new_name, data_torch
def write_tensors(self):
super().write_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
# same as Mamba
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del n_dims # unused
return bid is not None and new_name in (
self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
)
@Model.register("CohereForCausalLM") @Model.register("CohereForCausalLM")
class CommandR2Model(Model): class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R model_arch = gguf.MODEL_ARCH.COMMAND_R

View file

@ -187,6 +187,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
return ctx->support_simdgroup_reduction && return ctx->support_simdgroup_reduction &&
@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute(
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name); // ggml_is_contiguous(src1), src1->name);
//} //}
//if (src2) {
// GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne20, ne21, ne22,
// ggml_is_contiguous(src2), src2->name);
//}
//if (dst) { //if (dst) {
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name); // dst->name);
@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} }
} break; } break;
case GGML_OP_SSM_CONV:
{
id<MTLComputePipelineState> pipeline = nil;
//pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
//[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
//[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
//[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
//[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
//[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
//[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
//[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
//[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
//[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
//[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
//[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
//[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
//[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
//[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
//[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
//[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
//[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
//[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
//[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
//[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
//[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
//[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
//[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
//[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
//[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
//if (bcast_row) {
// const int64_t n = ggml_nelements(dst)/4;
// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
//} else {
// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
//}
} break;
case GGML_OP_DUP: case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT: case GGML_OP_CONT:

View file

@ -2698,6 +2698,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
kernel void kernel_ssm_conv_f32(
device const float * src0,
device const float * src1,
device const float * src2,
device const int32_t * src3,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne11,
constant int64_t & ne20,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb21,
constant uint64_t & nb22,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
}
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,
device half * dst, device half * dst,

74
ggml.c
View file

@ -7094,19 +7094,18 @@ struct ggml_tensor * ggml_ssm_conv(
GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(ggml_is_matrix(x)); GGML_ASSERT(ggml_is_matrix(x));
GGML_ASSERT(ggml_is_matrix(c)); GGML_ASSERT(ggml_is_matrix(c));
GGML_ASSERT(ggml_is_matrix(sq)); GGML_ASSERT(ggml_is_vector(sq));
GGML_ASSERT(sq->type == GGML_TYPE_I32); GGML_ASSERT(sq->type == GGML_TYPE_I32);
const int64_t d_conv = c->ne[0]; const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1]; const int64_t d_inner = c->ne[1];
const int64_t n_tokens = x->ne[1]; const int64_t n_tokens = x->ne[1];
const int64_t n_kv = s->ne[2]; const int64_t n_rs = s->ne[2];
GGML_ASSERT( s->ne[0] == d_conv - 1); GGML_ASSERT( s->ne[0] == d_conv - 1);
GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT( s->ne[1] == d_inner);
GGML_ASSERT( x->ne[0] == d_inner); GGML_ASSERT( x->ne[0] == d_inner);
GGML_ASSERT(sq->ne[0] == n_kv); GGML_ASSERT(sq->ne[0] == n_tokens);
GGML_ASSERT(sq->ne[1] == n_tokens);
bool is_node = false; bool is_node = false;
@ -7115,8 +7114,8 @@ struct ggml_tensor * ggml_ssm_conv(
is_node = true; is_node = true;
} }
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));
result->op = GGML_OP_SSM_CONV; result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -7169,7 +7168,7 @@ struct ggml_tensor * ggml_ssm_scan(
is_node = true; is_node = true;
} }
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
result->op = GGML_OP_SSM_SCAN; result->op = GGML_OP_SSM_SCAN;
@ -16241,9 +16240,9 @@ static void ggml_compute_forward_ssm_conv_f32(
const int nc = src2->ne[0]; // d_conv const int nc = src2->ne[0]; // d_conv
const int nr = src0->ne[1]; // d_inner const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // n_tokens const int n_t = src1->ne[1]; // n_tokens
const int n_kv = src0->ne[2]; // max number of sequences in the batch const int n_rs = src0->ne[2]; // max number of sequences in the batch
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float));
@ -16260,10 +16259,12 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
if (n_kv > 1) { const int32_t * sq = src3->data; // {n_tokens}
if (n_rs > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read, // multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure. // so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) { for (int i3 = 0; i3 < n_rs; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
// can't use memcpy because of d_conv vs d_conv - 1 // can't use memcpy because of d_conv vs d_conv - 1
@ -16277,19 +16278,19 @@ static void ggml_compute_forward_ssm_conv_f32(
} }
for (int i2 = 0; i2 < n_t; ++i2) { for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} int32_t sq_i = sq[i2];
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
float * s0; // {d_conv - 1, d_inner, n_kv} float * s0; // {d_conv - 1, d_inner, n_rs}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int ne0s0; int ne0s0;
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
// avoid needing to copy the state for the first token // avoid needing to copy the state for the first token
if (i2 == 0) { if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
ne0s0 = src0->ne[0]; ne0s0 = src0->ne[0];
} else { } else {
// the source is the last (d_conv - 1) columns of the destination // the source is the last (d_conv - 1) columns of the destination
@ -16307,18 +16308,6 @@ static void ggml_compute_forward_ssm_conv_f32(
s[(nc - 1) + i1*nc] = x0[i1]; s[(nc - 1) + i1*nc] = x0[i1];
} }
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
// it seems a little faster when this is separate from the state shift // it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) { for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product // rowwise dot product
@ -16370,7 +16359,7 @@ static void ggml_compute_forward_ssm_scan_f32(
const int64_t nc = src0->ne[0]; // d_state const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens in the batch const int64_t n_t = src1->ne[1]; // number of tokens in the batch
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -16379,6 +16368,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// required for the dot product between s and C, and when copying the states // required for the dot product between s and C, and when copying the states
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states // required for per-sequence offsets for states
@ -16394,10 +16384,12 @@ static void ggml_compute_forward_ssm_scan_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0; const int ir = ir1 - ir0;
if (n_kv > 1) { const int32_t * sq = src6->data; // {n_tokens}
if (n_rs > 1) {
// it's hard to know if the source states have already been copied // it's hard to know if the source states have already been copied
// when there are multiple, so copy them already. // when there are multiple, so copy them already.
for (int i3 = 0; i3 < n_kv; ++i3) { for (int i3 = 0; i3 < n_rs; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
memcpy(s, s0, nc*ir*sizeof(float)); memcpy(s, s0, nc*ir*sizeof(float));
@ -16405,9 +16397,9 @@ static void ggml_compute_forward_ssm_scan_f32(
} }
for (int i2 = 0; i2 < n_t; ++i2) { for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} int32_t sq_i = sq[i2];
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
float * s0; float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
@ -16415,11 +16407,11 @@ static void ggml_compute_forward_ssm_scan_f32(
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
// avoid needing to copy the state for the first token // avoid needing to copy the state for the first token
if (i2 == 0) { if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
} else { } else {
// otherwise the source is the same as the destination // otherwise the source is the same as the destination
s0 = s; s0 = s;
@ -16442,18 +16434,6 @@ static void ggml_compute_forward_ssm_scan_f32(
} }
y[i1] = sumf; y[i1] = sumf;
} }
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
} }
} }

View file

@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto() GEMMA = auto()
STARCODER2 = auto() STARCODER2 = auto()
MAMBA = auto() MAMBA = auto()
JAMBA = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
DBRX = auto() DBRX = auto()
@ -182,7 +183,10 @@ class MODEL_TENSOR(IntEnum):
SSM_CONV1D = auto() SSM_CONV1D = auto()
SSM_X = auto() SSM_X = auto()
SSM_DT = auto() SSM_DT = auto()
SSM_DT_NORM = auto()
SSM_A = auto() SSM_A = auto()
SSM_B_NORM = auto()
SSM_C_NORM = auto()
SSM_D = auto() SSM_D = auto()
SSM_OUT = auto() SSM_OUT = auto()
@ -216,6 +220,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.JAMBA: "jamba",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.DBRX: "dbrx",
@ -263,7 +268,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
} }
@ -682,6 +690,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT, MODEL_TENSOR.SSM_OUT,
], ],
MODEL_ARCH.JAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_X,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_DT_NORM,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.XVERSE: [ MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -385,8 +385,11 @@ class GGUFWriter:
def add_head_count(self, count: int) -> None: def add_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
def add_head_count_kv(self, count: int) -> None: def add_head_count_kv(self, count: int | Sequence[int]) -> None:
if isinstance(count, int):
self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
else:
self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
def add_key_length(self, length: int) -> None: def add_key_length(self, length: int) -> None:
self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)

View file

@ -206,6 +206,8 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2 "h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2 "model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"model.layers.{bid}.pre_ff_layernorm", # jamba
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
), ),
MODEL_TENSOR.FFN_GATE_INP: ( MODEL_TENSOR.FFN_GATE_INP: (
@ -214,6 +216,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate", # qwen2moe "model.layers.{bid}.mlp.gate", # qwen2moe
"transformer.decoder_layer.{bid}.router", # Grok "transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx "transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.feed_forward.router", # jamba
), ),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -245,6 +248,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_fc", # starcoder2 "model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic "model.layers.{bid}.residual_mlp.w3", # arctic
"model.layers.{bid}.feed_forward.up_proj", # jamba
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
@ -274,6 +278,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact "transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic "model.layers.{bid}.residual_mlp.w1", # arctic
"model.layers.{bid}.feed_forward.gate_proj", # jamba
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
@ -309,6 +314,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_proj", # starcoder2 "model.layers.{bid}.mlp.c_proj", # starcoder2
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w2", # arctic "model.layers.{bid}.residual_mlp.w2", # arctic
"model.layers.{bid}.feed_forward.down_proj", # jamba
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
@ -350,38 +356,59 @@ class TensorNameMap:
), ),
MODEL_TENSOR.SSM_IN: ( MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", "model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", "backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba
), ),
MODEL_TENSOR.SSM_CONV1D: ( MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d", "model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", "backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba
), ),
MODEL_TENSOR.SSM_X: ( MODEL_TENSOR.SSM_X: (
"model.layers.{bid}.x_proj", "model.layers.{bid}.x_proj", # mamba-hf
"backbone.layers.{bid}.mixer.x_proj", "backbone.layers.{bid}.mixer.x_proj", # mamba
"model.layers.{bid}.mamba.x_proj", # jamba
), ),
MODEL_TENSOR.SSM_DT: ( MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj", "model.layers.{bid}.dt_proj", # mamba-hf
"backbone.layers.{bid}.mixer.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba
),
MODEL_TENSOR.SSM_DT_NORM: (
"model.layers.{bid}.mamba.dt_layernorm", # jamba
), ),
MODEL_TENSOR.SSM_A: ( MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log", "model.layers.{bid}.A_log", # mamba-hf
"backbone.layers.{bid}.mixer.A_log", "backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba
),
MODEL_TENSOR.SSM_B_NORM: (
"model.layers.{bid}.mamba.b_layernorm", # jamba
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
),
MODEL_TENSOR.SSM_C_NORM: (
"model.layers.{bid}.mamba.c_layernorm", # jamba
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
), ),
MODEL_TENSOR.SSM_D: ( MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D", "model.layers.{bid}.D", # mamba-hf
"backbone.layers.{bid}.mixer.D", "backbone.layers.{bid}.mixer.D", # mamba
"model.layers.{bid}.mamba.D", # jamba
), ),
MODEL_TENSOR.SSM_OUT: ( MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", "model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba
), ),
} }

2348
llama.cpp

File diff suppressed because it is too large Load diff

72
llama.h
View file

@ -546,6 +546,12 @@ extern "C" {
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Rebuild and check the validity of the recurrent state cache's tree of sequences.
// (slow, use only for debugging purposes)
// Returns whether or not the rs cache was valid.
// The errors are always corrected, but only logged when debug is true.
LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug);
// Returns the number of tokens in the KV cache (slow, use only for debug) // Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@ -553,36 +559,62 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache - both cell info is erased and KV data is zeroed // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them)
LLAMA_API void llama_kv_cache_clear( LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
LLAMA_API void llama_cache_clear(
struct llama_context * ctx); struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_cache_clear instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm( // Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_cache_seq_rm(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_cache_seq_rm instead, and handle its return value for partial removals");
// Copy all tokens that belong to the specified sequence to another sequence // Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp( // Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_cache_seq_cp(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_cache_seq_cp instead, and handle its return value for partial copies");
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep( LLAMA_API void llama_cache_seq_keep(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_cache_seq_keep instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -590,12 +622,19 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add( LLAMA_API void llama_cache_seq_add(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta); llama_pos delta);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_cache_seq_add instead");
// Integer division of the positions by factor of `d > 1` // Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -603,17 +642,28 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div( LLAMA_API void llama_cache_seq_div(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d); int d);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_cache_seq_div instead");
// Returns the largest position present in the KV cache for the specified sequence // Returns the largest position present in the KV and/or RS cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max( LLAMA_API llama_pos llama_cache_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied:

View file

@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
} }
}; };
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type_s;
const ggml_type type_x;
const ggml_type type_c;
const ggml_type type_sq;
const int64_t d_inner;
const int64_t d_conv;
const int64_t n_tokens;
const int64_t n_rs;
std::string vars() override {
return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
}
test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
ggml_type type_x = GGML_TYPE_F32,
ggml_type type_c = GGML_TYPE_F32,
ggml_type type_sq = GGML_TYPE_I32,
int64_t d_inner = 10,
int64_t d_conv = 10,
int64_t n_tokens = 10,
int64_t n_rs = 10)
: type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// pos
std::vector<int> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = rand() % n_rs;
}
ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
enum llm_norm_type { enum llm_norm_type {
LLM_NORM, LLM_NORM,
LLM_NORM_RMS, LLM_NORM_RMS,
@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
} }
test_cases.emplace_back(new test_ssm_conv());
// these tests are disabled to save execution time, but they can be handy for debugging // these tests are disabled to save execution time, but they can be handy for debugging
#if 0 #if 0
test_cases.emplace_back(new test_llama(1)); test_cases.emplace_back(new test_llama(1));