assert correct base model tensor shapes
This commit is contained in:
parent
5ed309810e
commit
b0ee563748
1 changed files with 14 additions and 1 deletions
|
@ -250,8 +250,11 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
|
|||
model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
|
||||
model->output = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
|
||||
|
||||
model->layers.resize(hparams.n_layer);
|
||||
assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
|
||||
assert_shape_1d(model->norm, hparams.n_embd);
|
||||
assert_shape_2d(model->output, hparams.n_embd, hparams.n_vocab);
|
||||
|
||||
model->layers.resize(hparams.n_layer);
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
auto & layer = model->layers[i];
|
||||
|
||||
|
@ -264,6 +267,16 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
|
|||
layer.w1 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
layer.w2 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
layer.w3 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
|
||||
|
||||
assert_shape_1d(layer.attention_norm, hparams.n_embd);
|
||||
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
|
||||
assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
|
||||
assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
|
||||
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
|
||||
assert_shape_1d(layer.ffn_norm, hparams.n_embd);
|
||||
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);
|
||||
assert_shape_2d(layer.w2, hparams.n_ff, hparams.n_embd);
|
||||
assert_shape_2d(layer.w3, hparams.n_embd, hparams.n_ff);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue