assert correct base model tensor shapes

This commit is contained in:
xaedes 2023-09-17 16:43:12 +02:00
parent 5ed309810e
commit b0ee563748
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -250,8 +250,11 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
model->output = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
model->layers.resize(hparams.n_layer);
assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
assert_shape_1d(model->norm, hparams.n_embd);
assert_shape_2d(model->output, hparams.n_embd, hparams.n_vocab);
model->layers.resize(hparams.n_layer);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
auto & layer = model->layers[i];
@ -264,6 +267,16 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
layer.w1 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
layer.w2 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
layer.w3 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
assert_shape_1d(layer.attention_norm, hparams.n_embd);
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
assert_shape_1d(layer.ffn_norm, hparams.n_embd);
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);
assert_shape_2d(layer.w2, hparams.n_ff, hparams.n_embd);
assert_shape_2d(layer.w3, hparams.n_embd, hparams.n_ff);
}
}