From b0ee563748062e11bdba8c8c973bea7a187a5b6d Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 17 Sep 2023 16:43:12 +0200
Subject: [PATCH] assert correct base model tensor shapes

---
 examples/finetune/finetune.cpp | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 17b89a2f8..d0fc48f23 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -250,8 +250,11 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
     model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
     model->output = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
 
-    model->layers.resize(hparams.n_layer);
+    assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
+    assert_shape_1d(model->norm, hparams.n_embd);
+    assert_shape_2d(model->output, hparams.n_embd, hparams.n_vocab);
 
+    model->layers.resize(hparams.n_layer);
     for (uint32_t i = 0; i < hparams.n_layer; ++i) {
         auto & layer = model->layers[i];
 
@@ -264,6 +267,16 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
         layer.w1 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
         layer.w2 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
         layer.w3 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
+
+        assert_shape_1d(layer.attention_norm, hparams.n_embd);
+        assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
+        assert_shape_1d(layer.ffn_norm, hparams.n_embd);
+        assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);
+        assert_shape_2d(layer.w2, hparams.n_ff, hparams.n_embd);
+        assert_shape_2d(layer.w3, hparams.n_embd, hparams.n_ff);
     }
 }
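
Note: assert_shape_1d() and assert_shape_2d() are not defined in this patch; they are pre-existing helpers in the llama.cpp training code. A minimal, hypothetical sketch of what such helpers could look like, assuming ggml_tensor's ne[] array (which holds the size of each dimension, with unused dimensions set to 1) and the GGML_ASSERT macro:

    // Hypothetical sketch only; the actual helpers live elsewhere in the tree
    // and may perform additional checks (e.g. on the tensor's dimension count).
    #include "ggml.h"

    static void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
        GGML_ASSERT(tensor->ne[0] == ne0); // e.g. norm vectors must have n_embd elements
        GGML_ASSERT(tensor->ne[1] == 1);   // no second dimension for a 1D tensor
    }

    static void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
        GGML_ASSERT(tensor->ne[0] == ne0); // e.g. tok_embeddings: n_embd rows ...
        GGML_ASSERT(tensor->ne[1] == ne1); // ... by n_vocab columns
        GGML_ASSERT(tensor->ne[2] == 1);   // no third dimension for a 2D tensor
    }

Failing fast here surfaces a mismatched base model (wrong n_embd, n_ff, or n_vocab) at load time instead of producing silently broken finetuning results later.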