Compare commits
17 commits
master ... compilade/
Author | SHA1 | Date
--- | --- | ---
 | ddc59e8e0a |
 | fc59407efe |
 | ea2e63e9d2 |
 | 61a88a1da3 |
 | 0fd13e9473 |
 | cbc743e600 |
 | 7e13f19fb5 |
 | 3b57b55c6f |
 | b7ec12ebf7 |
 | b6fafd1747 |
 | c460ff1a1c |
 | a09db95eab |
 | d66849f628 |
 | 0c8b3b2095 |
 | 0028010d01 |
 | 8db1e4d45f |
 | 271104c65c |

10 changed files with 2315 additions and 705 deletions
@@ -2338,7 +2338,7 @@ class MambaModel(Model):
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
@@ -2384,6 +2384,135 @@ class MambaModel(Model):
         )
 
 
+@Model.register("JambaForCausalLM")
+class JambaModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield new_name, data_torch
+
+    def write_tensors(self):
+        super().write_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+    # same as Mamba
+    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+        del n_dims # unused
+
+        return bid is not None and new_name in (
+            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
+                gguf.MODEL_TENSOR.SSM_CONV1D,
+                gguf.MODEL_TENSOR.SSM_X,
+                gguf.MODEL_TENSOR.SSM_DT,
+                gguf.MODEL_TENSOR.SSM_A,
+                gguf.MODEL_TENSOR.SSM_D,
+            ]
+        )
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
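Two parts of the new `JambaModel.set_gguf_parameters` above are easy to misread: the negative-floor-division trick used as the `dt_rank` default, and the per-layer KV-head vector for a hybrid attention/Mamba stack. Below is a minimal sketch of both; all numeric values (hidden size, offsets, periods, head counts) are made-up illustrations, not taken from any real Jamba config.

```python
# Hypothetical hyperparameters, chosen only to exercise the two expressions above.
d_model = 4096
attn_offset, attn_period = 4, 8   # assumed: index of the first attention layer, and its repeat period
n_kv_head = 8
block_count = 32

# ceiling division without math.ceil: -(a // -b) == ceil(a / b)
dt_rank = -(d_model // -16)
assert dt_rank == 256  # ceil(4096 / 16)

# one entry per layer: attention layers get n_kv_head, Mamba layers get 0
n_kv_vec = [0 for _ in range(attn_offset)] + [
    n_kv_head if (i - attn_offset) % attn_period == 0 else 0
    for i in range(attn_offset, block_count)
]
assert len(n_kv_vec) == block_count
assert n_kv_vec[attn_offset] == n_kv_head  # layer 4 is an attention layer
assert n_kv_vec[attn_offset + 1] == 0      # layer 5 is a Mamba layer
```

Passing a list here is what motivates the `add_head_count_kv` change in `GGUFWriter` further down.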

ggml-metal.m (56 changes)
@@ -187,6 +187,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute(
         //    GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
         //        ggml_is_contiguous(src1), src1->name);
         //}
+        //if (src2) {
+        //    GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne20, ne21, ne22,
+        //        ggml_is_contiguous(src2), src2->name);
+        //}
         //if (dst) {
         //    GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
         //        dst->name);
@@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
+        case GGML_OP_SSM_CONV:
+            {
+                id<MTLComputePipelineState> pipeline = nil;
+
+                //pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                //[encoder setComputePipelineState:pipeline];
+                //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                //[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                //[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                //[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                //[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                //[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                //[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                //[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                //[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                //[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                //[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                //[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                //[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                //[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                //[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                //[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                //[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                //[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                //[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                //[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                //[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                //[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                //[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                //[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+                //[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+                //[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                //[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                //[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
+
+                //if (bcast_row) {
+                //    const int64_t n = ggml_nelements(dst)/4;
+
+                //    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                //} else {
+                //    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                //    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                //}
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -2698,6 +2698,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
+kernel void kernel_ssm_conv_f32(
+        device const float * src0,
+        device const float * src1,
+        device const float * src2,
+        device const int32_t * src3,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne11,
+        constant int64_t & ne20,
+
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device half * dst,

ggml.c (94 changes)
@@ -7094,19 +7094,18 @@ struct ggml_tensor * ggml_ssm_conv(
     GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(ggml_is_matrix(x));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
+    GGML_ASSERT(ggml_is_vector(sq));
     GGML_ASSERT(sq->type == GGML_TYPE_I32);
 
     const int64_t d_conv = c->ne[0];
     const int64_t d_inner = c->ne[1];
     const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv = s->ne[2];
+    const int64_t n_rs = s->ne[2];
 
     GGML_ASSERT( s->ne[0] == d_conv - 1);
     GGML_ASSERT( s->ne[1] == d_inner);
     GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    GGML_ASSERT(sq->ne[0] == n_tokens);
 
     bool is_node = false;
 
@@ -7115,8 +7114,8 @@ struct ggml_tensor * ggml_ssm_conv(
         is_node = true;
     }
 
-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));
 
     result->op = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7169,7 +7168,7 @@ struct ggml_tensor * ggml_ssm_scan(
         is_node = true;
     }
 
-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op = GGML_OP_SSM_SCAN;
@@ -16241,9 +16240,9 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int nc = src2->ne[0]; // d_conv
     const int nr = src0->ne[1]; // d_inner
     const int n_t = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int n_rs = src0->ne[2]; // max number of sequences in the batch
 
-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
@@ -16260,10 +16259,12 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir = ir1 - ir0;
 
-    if (n_kv > 1) {
+    const int32_t * sq = src3->data; // {n_tokens}
+
+    if (n_rs > 1) {
         // multiple sequences means it's hard to know when it's the first time a state is read,
         // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
+        for (int i3 = 0; i3 < n_rs; ++i3) {
             float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
             float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
             // can't use memcpy because of d_conv vs d_conv - 1
@@ -16277,19 +16278,19 @@ static void ggml_compute_forward_ssm_conv_f32(
     }
 
     for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
+        int32_t sq_i = sq[i2];
         float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float * s0; // {d_conv - 1, d_inner, n_kv}
+        float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
+        float * s0; // {d_conv - 1, d_inner, n_rs}
         float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
         float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
         int ne0s0;
 
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+        GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
 
         // avoid needing to copy the state for the first token
         if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
             ne0s0 = src0->ne[0];
         } else {
             // the source is the last (d_conv - 1) columns of the destination
@@ -16307,18 +16308,6 @@ static void ggml_compute_forward_ssm_conv_f32(
             s[(nc - 1) + i1*nc] = x0[i1];
         }
 
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
         // it seems a little faster when this is separate from the state shift
         for (int i1 = 0; i1 < ir; ++i1) {
             // rowwise dot product
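To make the `sq` change easier to follow: the op now receives one sequence id per token (`{n_tokens}`) instead of an `{n_kv, n_tokens}` matrix, and the cross-sequence copy loop above is dropped. The following NumPy sketch is only an illustrative restatement of what the per-token loop computes; it does not mirror ggml's memory layout or the 2-in-1 destination tensor.

```python
import numpy as np

def ssm_conv_ref(state, x, c, sq):
    """state: (n_rs, d_inner, d_conv-1)  per-sequence conv state
    x:     (d_inner, n_tokens)           current inputs
    c:     (d_inner, d_conv)             conv weights
    sq:    (n_tokens,)                   sequence id of each token"""
    d_inner, n_tokens = x.shape
    y = np.empty((d_inner, n_tokens), dtype=x.dtype)
    state = state.copy()
    for t in range(n_tokens):
        s = state[sq[t]]                                     # state of this token's sequence
        window = np.concatenate([s, x[:, t:t + 1]], axis=1)  # shift in the new column
        y[:, t] = (window * c).sum(axis=1)                   # rowwise dot product
        state[sq[t]] = window[:, 1:]                         # keep the last d_conv-1 columns
    return y, state
```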
@@ -16370,7 +16359,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int64_t nc = src0->ne[0]; // d_state
     const int64_t nr = src0->ne[1]; // d_inner
     const int64_t n_t = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch
 
     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -16379,6 +16368,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
     // required for the dot product between s and C, and when copying the states
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
@@ -16394,10 +16384,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir = ir1 - ir0;
 
-    if (n_kv > 1) {
+    const int32_t * sq = src6->data; // {n_tokens}
+
+    if (n_rs > 1) {
         // it's hard to know if the source states have already been copied
         // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
+        for (int i3 = 0; i3 < n_rs; ++i3) {
             float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
             float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
             memcpy(s, s0, nc*ir*sizeof(float));
@@ -16405,21 +16397,21 @@ static void ggml_compute_forward_ssm_scan_f32(
     }
 
     for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
+        int32_t sq_i = sq[i2];
         float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+        float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
         float * s0;
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
         float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
 
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+        GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
 
         // avoid needing to copy the state for the first token
         if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
         } else {
             // otherwise the source is the same as the destination
             s0 = s;
@@ -16442,18 +16434,6 @@ static void ggml_compute_forward_ssm_scan_f32(
             }
             y[i1] = sumf;
         }
 
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
     }
 }
@@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA = auto()
     STARCODER2 = auto()
     MAMBA = auto()
+    JAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
     DBRX = auto()
@@ -182,7 +183,10 @@ class MODEL_TENSOR(IntEnum):
     SSM_CONV1D = auto()
     SSM_X = auto()
     SSM_DT = auto()
+    SSM_DT_NORM = auto()
     SSM_A = auto()
+    SSM_B_NORM = auto()
+    SSM_C_NORM = auto()
     SSM_D = auto()
     SSM_OUT = auto()
 
@@ -216,6 +220,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.MAMBA: "mamba",
+    MODEL_ARCH.JAMBA: "jamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
@@ -263,7 +268,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
     MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
     MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
+    MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
 }
@@ -682,6 +690,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_D,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -385,8 +385,11 @@ class GGUFWriter:
     def add_head_count(self, count: int) -> None:
         self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
 
-    def add_head_count_kv(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 
     def add_key_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
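With this change `add_head_count_kv` accepts either a single integer (written as a `uint32`, as before) or a per-layer sequence (written as an array), which is what the Jamba converter passes. A small illustrative call; the output path and the values are assumptions for the example only:

```python
from gguf import GGUFWriter  # assumed import path for gguf-py

writer = GGUFWriter("model.gguf", "jamba")
writer.add_head_count(32)                            # scalar, stored as uint32
writer.add_head_count_kv([0, 0, 0, 0, 8, 0, 0, 0])   # per-layer KV heads, stored as an array
```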
@@ -206,6 +206,8 @@ class TensorNameMap:
             "h.{bid}.ln_2", # gpt2
             "model.layers.{bid}.ffn_norm", # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+            "model.layers.{bid}.pre_ff_layernorm", # jamba
+            "model.layers.{bid}.pre_moe_layernorm", # mini-jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -214,6 +216,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate", # qwen2moe
             "transformer.decoder_layer.{bid}.router", # Grok
             "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+            "model.layers.{bid}.feed_forward.router", # jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -245,6 +248,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.c_fc", # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3", # arctic
+            "model.layers.{bid}.feed_forward.up_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -274,6 +278,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
+            "model.layers.{bid}.feed_forward.gate_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -309,6 +314,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.c_proj", # starcoder2
             "encoder.layer.{bid}.mlp.wo", # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w2", # arctic
+            "model.layers.{bid}.feed_forward.down_proj", # jamba
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -350,38 +356,59 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",
-            "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.in_proj", # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj", # mamba
+            "model.layers.{bid}.mamba.in_proj", # jamba
         ),
 
         MODEL_TENSOR.SSM_CONV1D: (
-            "model.layers.{bid}.conv1d",
-            "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.conv1d", # mamba-hf
+            "backbone.layers.{bid}.mixer.conv1d", # mamba
+            "model.layers.{bid}.mamba.conv1d", # jamba
         ),
 
         MODEL_TENSOR.SSM_X: (
-            "model.layers.{bid}.x_proj",
-            "backbone.layers.{bid}.mixer.x_proj",
+            "model.layers.{bid}.x_proj", # mamba-hf
+            "backbone.layers.{bid}.mixer.x_proj", # mamba
+            "model.layers.{bid}.mamba.x_proj", # jamba
         ),
 
         MODEL_TENSOR.SSM_DT: (
-            "model.layers.{bid}.dt_proj",
-            "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.dt_proj", # mamba-hf
+            "backbone.layers.{bid}.mixer.dt_proj", # mamba
+            "model.layers.{bid}.mamba.dt_proj", # jamba
+        ),
+
+        MODEL_TENSOR.SSM_DT_NORM: (
+            "model.layers.{bid}.mamba.dt_layernorm", # jamba
         ),
 
         MODEL_TENSOR.SSM_A: (
-            "model.layers.{bid}.A_log",
-            "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.A_log", # mamba-hf
+            "backbone.layers.{bid}.mixer.A_log", # mamba
+            "model.layers.{bid}.mamba.A_log", # jamba
+        ),
+
+        MODEL_TENSOR.SSM_B_NORM: (
+            "model.layers.{bid}.mamba.b_layernorm", # jamba
+            "model.layers.{bid}.mamba.B_layernorm", # mini-jamba
+        ),
+
+        MODEL_TENSOR.SSM_C_NORM: (
+            "model.layers.{bid}.mamba.c_layernorm", # jamba
+            "model.layers.{bid}.mamba.C_layernorm", # mini-jamba
         ),
 
         MODEL_TENSOR.SSM_D: (
-            "model.layers.{bid}.D",
-            "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.D", # mamba-hf
+            "backbone.layers.{bid}.mixer.D", # mamba
+            "model.layers.{bid}.mamba.D", # jamba
        ),
 
         MODEL_TENSOR.SSM_OUT: (
-            "model.layers.{bid}.out_proj",
-            "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.out_proj", # mamba-hf
+            "backbone.layers.{bid}.mixer.out_proj", # mamba
+            "model.layers.{bid}.mamba.out_proj", # jamba
         ),
     }
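The new `# jamba` entries let the converter resolve HF-style Jamba tensor names through the existing `TensorNameMap`. A few expected pairs are sketched below, following the `TENSOR_NAMES` patterns; the layer index 3 is an arbitrary example and the right-hand names for `ssm_in` and `ffn_gate_inp` are assumed from the rest of the file rather than shown in this diff:

```python
# Illustrative only: HF checkpoint name -> GGUF tensor name for a hypothetical layer 3.
expected = {
    "model.layers.3.mamba.in_proj.weight":       "blk.3.ssm_in.weight",
    "model.layers.3.mamba.dt_layernorm.weight":  "blk.3.ssm_dt_norm.weight",
    "model.layers.3.mamba.b_layernorm.weight":   "blk.3.ssm_b_norm.weight",
    "model.layers.3.feed_forward.router.weight": "blk.3.ffn_gate_inp.weight",
}
```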

llama.h (72 changes)
@@ -546,6 +546,12 @@ extern "C" {
     // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
     LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
 
+    // Rebuild and check the validity of the recurrent state cache's tree of sequences.
+    // (slow, use only for debugging purposes)
+    // Returns whether or not the rs cache was valid.
+    // The errors are always corrected, but only logged when debug is true.
+    LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug);
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@@ -553,36 +559,62 @@ extern "C" {
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
     LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
-    // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_cache_clear(
+    // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them)
+    LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
+
+    // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
+    LLAMA_API void llama_cache_clear(
             struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_cache_clear(
+            struct llama_context * ctx),
+        "use llama_cache_clear instead");
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
     // seq_id < 0 : match any sequence
     // p0 < 0 : [0, p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API bool llama_kv_cache_seq_rm(
+    // Returns n_past (one more than the largest remaining pos in the seq_id)
+    // which is only meaningful to handle for partial removals.
+    LLAMA_API llama_pos llama_cache_seq_rm(
             struct llama_context * ctx,
             llama_seq_id seq_id,
             llama_pos p0,
             llama_pos p1);
+    LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_cache_seq_rm instead, and handle its return value for partial removals");
 
     // Copy all tokens that belong to the specified sequence to another sequence
-    // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+    // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0, p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_cp(
+    // Returns n_past (one more than the largest remaining pos in the destination seq_id)
+    // which is only meaningful to handle when partially copying.
+    LLAMA_API llama_pos llama_cache_seq_cp(
            struct llama_context * ctx,
            llama_seq_id seq_id_src,
            llama_seq_id seq_id_dst,
            llama_pos p0,
            llama_pos p1);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
+            struct llama_context * ctx,
+            llama_seq_id seq_id_src,
+            llama_seq_id seq_id_dst,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_cache_seq_cp instead, and handle its return value for partial copies");
 
     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_seq_keep(
+    LLAMA_API void llama_cache_seq_keep(
            struct llama_context * ctx,
            llama_seq_id seq_id);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_cache_seq_keep instead");
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -590,12 +622,19 @@ extern "C" {
     // - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0, p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_add(
+    LLAMA_API void llama_cache_seq_add(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,
            llama_pos p1,
            llama_pos delta);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            llama_pos delta),
+        "use llama_cache_seq_add instead");
 
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -603,17 +642,28 @@ extern "C" {
     // - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0, p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_div(
+    LLAMA_API void llama_cache_seq_div(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,
            llama_pos p1,
            int d);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            int d),
+        "use llama_cache_seq_div instead");
 
-    // Returns the largest position present in the KV cache for the specified sequence
-    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+    // Returns the largest position present in the KV and/or RS cache for the specified sequence
+    LLAMA_API llama_pos llama_cache_seq_pos_max(
            struct llama_context * ctx,
            llama_seq_id seq_id);
+    LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
 
     // Defragment the KV cache
     // This will be applied:
@@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
     }
 };
 
+// GGML_OP_SSM_CONV
+struct test_ssm_conv : public test_case {
+    const ggml_type type_s;
+    const ggml_type type_x;
+    const ggml_type type_c;
+    const ggml_type type_sq;
+    const int64_t d_inner;
+    const int64_t d_conv;
+    const int64_t n_tokens;
+    const int64_t n_rs;
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
+    }
+
+    test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
+            ggml_type type_x = GGML_TYPE_F32,
+            ggml_type type_c = GGML_TYPE_F32,
+            ggml_type type_sq = GGML_TYPE_I32,
+            int64_t d_inner = 10,
+            int64_t d_conv = 10,
+            int64_t n_tokens = 10,
+            int64_t n_rs = 10)
+        : type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
+        ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
+        ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
+        ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
+        ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector<int> data(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data[i] = rand() % n_rs;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_ssm_conv());
+
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));