tts : fix tensor shapes

This commit is contained in:
Georgi Gerganov 2024-12-16 16:48:22 +02:00
parent c096bbd8dd
commit d1ef627c51
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 101 additions and 75 deletions

View file

@ -326,6 +326,8 @@ class Model:
gguf.MODEL_TENSOR.TIME_MIX_W2,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
gguf.MODEL_TENSOR.POS_NET_NORM1,
gguf.MODEL_TENSOR.POS_NET_NORM2,
)
)
or not new_name.endswith(".weight")
@ -2060,6 +2062,8 @@ class WavTokenizerDecModel(Model):
self.gguf_writer.add_posnet_length (self.hparams["n_embd_posnet"])
self.gguf_writer.add_convnext_length (self.hparams["n_embd_convnext"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"])
self.gguf_writer.add_group_norm_eps (self.hparams["group_norm_epsilon"])
self.gguf_writer.add_group_norm_groups (self.hparams["group_norm_groups"])
@Model.register("Qwen2MoeForCausalLM")

View file

@ -98,6 +98,13 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
if new_key.endswith("gamma"):
new_key = new_key.replace("gamma", "gamma.weight")
# convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.pos_net") or new_key.startswith("backbone.embed.bias")):
value = value.unsqueeze(1)
if new_key.endswith("dwconv.bias"):
value = value.unsqueeze(1)
size_mb = value.element_size() * value.nelement() / (1024 * 1024)
print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
@ -154,6 +161,8 @@ config = {
"vocab_size": 4096,
"n_head": 1,
"layer_norm_epsilon": 1e-6,
"group_norm_epsilon": 1e-6,
"group_norm_groups": 32,
"max_position_embeddings": 8192, # ?
"num_hidden_layers": 12
}

View file

@ -125,6 +125,8 @@ class Keys:
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"

View file

@ -631,9 +631,6 @@ class GGUFWriter:
def add_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
def add_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
def add_features_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
@ -739,6 +736,12 @@ class GGUFWriter:
def add_layer_norm_rms_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
def add_group_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
def add_group_norm_groups(self, value: int) -> None:
self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

View file

@ -310,6 +310,8 @@ enum llm_kv {
LLM_KV_ATTENTION_VALUE_LENGTH,
LLM_KV_ATTENTION_LAYERNORM_EPS,
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
LLM_KV_ATTENTION_GROUPNORM_EPS,
LLM_KV_ATTENTION_GROUPNORM_GROUPS,
LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
@ -430,6 +432,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
@ -2571,6 +2575,9 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
uint32_t n_norm_groups;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
@ -6464,6 +6471,8 @@ static void llm_load_hparams(
case LLM_ARCH_WAVTOKENIZER_DEC:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps);
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
} break;
default: (void)0;
}
@ -9575,79 +9584,79 @@ static bool llm_load_tensors(
model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd_posnet}, 0);
model.conv_1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, n_embd_features, n_embd_posnet}, 0);
model.conv_1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {n_embd_posnet}, 0);
model.conv_1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, n_embd_posnet}, 0);
model.posnet_0_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 0), {n_embd_posnet}, 0);
model.posnet_0_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 0), {n_embd_posnet}, 0);
model.posnet_0_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 0), {1, n_embd_posnet}, 0);
model.posnet_0_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 0), {1, n_embd_posnet}, 0);
model.posnet_0_conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", 0), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_0_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 0), {n_embd_posnet}, 0);
model.posnet_0_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 0), {1, n_embd_posnet}, 0);
model.posnet_0_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 0), {n_embd_posnet}, 0);
model.posnet_0_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 0), {n_embd_posnet}, 0);
model.posnet_0_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 0), {1, n_embd_posnet}, 0);
model.posnet_0_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 0), {1, n_embd_posnet}, 0);
model.posnet_0_conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", 0), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_0_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 0), {n_embd_posnet}, 0);
model.posnet_0_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 0), {1, n_embd_posnet}, 0);
model.posnet_1_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 1), {n_embd_posnet}, 0);
model.posnet_1_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 1), {n_embd_posnet}, 0);
model.posnet_1_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 1), {1, n_embd_posnet}, 0);
model.posnet_1_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 1), {1, n_embd_posnet}, 0);
model.posnet_1_conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", 1), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_1_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 1), {n_embd_posnet}, 0);
model.posnet_1_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 1), {1, n_embd_posnet}, 0);
model.posnet_1_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 1), {n_embd_posnet}, 0);
model.posnet_1_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 1), {n_embd_posnet}, 0);
model.posnet_1_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 1), {1, n_embd_posnet}, 0);
model.posnet_1_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 1), {1, n_embd_posnet}, 0);
model.posnet_1_conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", 1), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_1_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 1), {n_embd_posnet}, 0);
model.posnet_1_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 1), {1, n_embd_posnet}, 0);
model.posnet_2_attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 2), {1, n_embd_posnet}, 0);
model.posnet_2_attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", 2), {1, n_embd_posnet}, 0);
model.posnet_2_attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", 2), {1, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_2_attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", 2), {1, n_embd_posnet}, 0);
model.posnet_2_attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", 2), {1, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_2_attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", 2), {1, n_embd_posnet}, 0);
model.posnet_2_attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", 2), {1, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_2_attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", 2), {1, n_embd_posnet}, 0);
model.posnet_2_attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", 2), {1, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_2_attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", 2), {n_embd_posnet}, 0);
model.posnet_2_attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", 2), {1, n_embd_posnet}, 0);
model.posnet_3_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 3), {n_embd_posnet}, 0);
model.posnet_3_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 3), {n_embd_posnet}, 0);
model.posnet_3_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 3), {1, n_embd_posnet}, 0);
model.posnet_3_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 3), {1, n_embd_posnet}, 0);
model.posnet_3_conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", 3), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_3_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 3), {n_embd_posnet}, 0);
model.posnet_3_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 3), {1, n_embd_posnet}, 0);
model.posnet_3_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 3), {n_embd_posnet}, 0);
model.posnet_3_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 3), {n_embd_posnet}, 0);
model.posnet_3_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 3), {1, n_embd_posnet}, 0);
model.posnet_3_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 3), {1, n_embd_posnet}, 0);
model.posnet_3_conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", 3), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_3_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 3), {n_embd_posnet}, 0);
model.posnet_3_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 3), {1, n_embd_posnet}, 0);
model.posnet_4_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 4), {n_embd_posnet}, 0);
model.posnet_4_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 4), {n_embd_posnet}, 0);
model.posnet_4_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 4), {1, n_embd_posnet}, 0);
model.posnet_4_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 4), {1, n_embd_posnet}, 0);
model.posnet_4_conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", 4), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_4_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 4), {n_embd_posnet}, 0);
model.posnet_4_conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", 4), {1, n_embd_posnet}, 0);
model.posnet_4_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 4), {n_embd_posnet}, 0);
model.posnet_4_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 4), {n_embd_posnet}, 0);
model.posnet_4_norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", 4), {1, n_embd_posnet}, 0);
model.posnet_4_norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", 4), {1, n_embd_posnet}, 0);
model.posnet_4_conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", 4), {3, n_embd_posnet, n_embd_posnet}, 0);
model.posnet_4_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 4), {n_embd_posnet}, 0);
model.posnet_4_conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", 4), {1, n_embd_posnet}, 0);
model.posnet_5_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 5), {n_embd_posnet}, 0);
model.posnet_5_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", 5), {n_embd_posnet}, 0);
model.posnet_5_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 5), {1, n_embd_posnet}, 0);
model.posnet_5_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", 5), {1, n_embd_posnet}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.convnext_dw = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "weight", i), {7, 1, n_embd_convnext}, 0);
layer.convnext_dw_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "bias", i), {n_embd_convnext}, 0);
layer.convnext_dw_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "bias", i), {1, n_embd_convnext}, 0);
layer.convnext_norm = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "weight", i), {n_embd_convnext}, 0);
layer.convnext_norm_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "bias", i), {n_embd_convnext}, 0);
@ -10033,9 +10042,8 @@ static struct ggml_tensor * llm_build_norm(
case LLM_NORM_RMS: cur = ggml_rms_norm (ctx, cur, hparams.f_norm_rms_eps); break;
case LLM_NORM_GROUP:
{
// TODO: these reshapes should be removed, fix ggml_group_norm
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
cur = ggml_group_norm(ctx, cur, 32, 1e-6); // TODO: add groups, eps params
cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
cur = ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]);
} break;
}
@ -17256,31 +17264,31 @@ struct llm_build_context {
cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
cur = ggml_conv_1d_ph(ctx0, model.conv_1d, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.conv_1d_b, 1, model.conv_1d_b->ne[0]));
cur = ggml_add(ctx0, cur, model.conv_1d_b);
inpL = cur;
// resnet block 0
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_0_norm1, 1, model.posnet_0_norm1->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_0_norm1_b, 1, model.posnet_0_norm1_b->ne[0]),
model.posnet_0_norm1,
model.posnet_0_norm1_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_0_conv1, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_0_conv1_b, 1, model.posnet_0_conv1_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_0_conv1_b);
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_0_norm2, 1, model.posnet_0_norm2->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_0_norm2_b, 1, model.posnet_0_norm2_b->ne[0]),
model.posnet_0_norm2,
model.posnet_0_norm2_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_0_conv2, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_0_conv2_b, 1, model.posnet_0_conv2_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_0_conv2_b);
cur = ggml_add(ctx0, cur, inpL);
}
@ -17290,24 +17298,24 @@ struct llm_build_context {
// resnet block 1
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_1_norm1, 1, model.posnet_1_norm1->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_1_norm1_b, 1, model.posnet_1_norm1_b->ne[0]),
model.posnet_1_norm1,
model.posnet_1_norm1_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_1_conv1, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_1_conv1_b, 1, model.posnet_1_conv1_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_1_conv1_b);
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_1_norm2, 1, model.posnet_1_norm2->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_1_norm2_b, 1, model.posnet_1_norm2_b->ne[0]),
model.posnet_1_norm2,
model.posnet_1_norm2_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_1_conv2, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_1_conv2_b, 1, model.posnet_1_conv2_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_1_conv2_b);
cur = ggml_add(ctx0, cur, inpL);
}
@ -17317,8 +17325,8 @@ struct llm_build_context {
// attention block
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_2_attn_norm, 1, model.posnet_2_attn_norm->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_2_attn_norm_b, 1, model.posnet_2_attn_norm_b->ne[0]),
model.posnet_2_attn_norm,
model.posnet_2_attn_norm_b,
LLM_NORM_GROUP, cb, 0);
struct ggml_tensor * q;
@ -17329,9 +17337,9 @@ struct llm_build_context {
k = ggml_conv_1d_ph(ctx0, model.posnet_2_attn_k, cur, 1, 1);
v = ggml_conv_1d_ph(ctx0, model.posnet_2_attn_v, cur, 1, 1);
q = ggml_add(ctx0, q, ggml_reshape_2d(ctx0, model.posnet_2_attn_q_b, 1, model.posnet_2_attn_q_b->ne[0]));
k = ggml_add(ctx0, k, ggml_reshape_2d(ctx0, model.posnet_2_attn_k_b, 1, model.posnet_2_attn_k_b->ne[0]));
v = ggml_add(ctx0, v, ggml_reshape_2d(ctx0, model.posnet_2_attn_v_b, 1, model.posnet_2_attn_v_b->ne[0]));
q = ggml_add(ctx0, q, model.posnet_2_attn_q_b);
k = ggml_add(ctx0, k, model.posnet_2_attn_k_b);
v = ggml_add(ctx0, v, model.posnet_2_attn_v_b);
q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
@ -17343,7 +17351,7 @@ struct llm_build_context {
cur = ggml_mul_mat(ctx0, kq, v);
cur = ggml_conv_1d_ph(ctx0, model.posnet_2_attn_o, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_2_attn_o_b, 1, model.posnet_2_attn_o_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_2_attn_o_b);
cur = ggml_add(ctx0, cur, inpL);
}
@ -17353,24 +17361,24 @@ struct llm_build_context {
// resnet block 3
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_3_norm1, 1, model.posnet_3_norm1->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_3_norm1_b, 1, model.posnet_3_norm1_b->ne[0]),
model.posnet_3_norm1,
model.posnet_3_norm1_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_3_conv1, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_3_conv1_b, 1, model.posnet_3_conv1_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_3_conv1_b);
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_3_norm2, 1, model.posnet_3_norm2->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_3_norm2_b, 1, model.posnet_3_norm2_b->ne[0]),
model.posnet_3_norm2,
model.posnet_3_norm2_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_3_conv2, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_3_conv2_b, 1, model.posnet_3_conv2_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_3_conv2_b);
cur = ggml_add(ctx0, cur, inpL);
}
@ -17380,24 +17388,24 @@ struct llm_build_context {
// resnet block 4
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_4_norm1, 1, model.posnet_4_norm1->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_4_norm1_b, 1, model.posnet_4_norm1_b->ne[0]),
model.posnet_4_norm1,
model.posnet_4_norm1_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_4_conv1, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_4_conv1_b, 1, model.posnet_4_conv1_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_4_conv1_b);
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_4_norm2, 1, model.posnet_4_norm2->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_4_norm2_b, 1, model.posnet_4_norm2_b->ne[0]),
model.posnet_4_norm2,
model.posnet_4_norm2_b,
LLM_NORM_GROUP, cb, 0);
cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
cur = ggml_conv_1d_ph(ctx0, model.posnet_4_conv2, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.posnet_4_conv2_b, 1, model.posnet_4_conv2_b->ne[0]));
cur = ggml_add(ctx0, cur, model.posnet_4_conv2_b);
cur = ggml_add(ctx0, cur, inpL);
}
@ -17405,8 +17413,8 @@ struct llm_build_context {
// normalize block 5
{
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_5_norm, 1, model.posnet_5_norm->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_5_norm_b, 1, model.posnet_5_norm_b->ne[0]),
model.posnet_5_norm,
model.posnet_5_norm_b,
LLM_NORM_GROUP, cb, 0);
}
@ -17425,7 +17433,7 @@ struct llm_build_context {
cur = inpL;
cur = ggml_conv_1d_dw_ph(ctx0, model.layers[il].convnext_dw, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.layers[il].convnext_dw_b, 1, model.layers[il].convnext_dw_b->ne[0]));
cur = ggml_add(ctx0, cur, model.layers[il].convnext_dw_b);
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));