the shapes for init model of gqa models was wrong
This commit is contained in:
xaedes 2023-10-02 14:44:21 +02:00
parent f5ef5cfb18
commit f80994b09c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
assert_shape_1d(layer.attention_norm, hparams.n_embd); assert_shape_1d(layer.attention_norm, hparams.n_embd);
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd); assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd); assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa());
assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd); assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa());
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd); assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
assert_shape_1d(layer.ffn_norm, hparams.n_embd); assert_shape_1d(layer.ffn_norm, hparams.n_embd);
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff); assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);