replace llama API functions to get model tensors by one function to get model tensor by name

LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
This commit is contained in:
xaedes 2023-08-16 21:36:40 +02:00
parent 39a2d15461
commit 1151653b15
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
3 changed files with 38 additions and 92 deletions

View file

@ -352,27 +352,44 @@ void init_model(struct llama_model * input, struct my_llama_model * model, uint3
const uint32_t n_ff = get_n_ff(&hparams); const uint32_t n_ff = get_n_ff(&hparams);
model->tok_embeddings = llama_get_model_tok_embeddings(input); model->tok_embeddings = llama_get_model_tensor(input, "tok_embeddings.weight");
model->norm = llama_get_model_norm(input); model->norm = llama_get_model_tensor(input, "norm.weight");
model->output = llama_get_model_output(input); model->output = llama_get_model_tensor(input, "output.weight");
model->layers.resize(n_layer); model->layers.resize(n_layer);
char name[GGML_MAX_NAME];
for (uint32_t i = 0; i < n_layer; ++i) { for (uint32_t i = 0; i < n_layer; ++i) {
struct llama_layer * ilayer = llama_get_layer_from_model(input, i); struct llama_layer * ilayer = llama_get_layer_from_model(input, i);
auto & layer = model->layers[i]; auto & layer = model->layers[i];
layer.attention_norm = llama_get_layer_attention_norm(ilayer); snprintf(name, GGML_MAX_NAME, "layers.%d.attention_norm.weight", i);
layer.attention_norm = llama_get_model_tensor(input, name);
layer.wq = llama_get_layer_wq(ilayer); snprintf(name, GGML_MAX_NAME, "layers.%d.attention.wq.weight", i);
layer.wk = llama_get_layer_wk(ilayer); layer.wq = llama_get_model_tensor(input, name);
layer.wv = llama_get_layer_wv(ilayer);
layer.wo = llama_get_layer_wo(ilayer);
layer.ffn_norm = llama_get_layer_ffn_norm(ilayer); snprintf(name, GGML_MAX_NAME, "layers.%d.attention.wk.weight", i);
layer.wk = llama_get_model_tensor(input, name);
layer.w1 = llama_get_layer_w1(ilayer); snprintf(name, GGML_MAX_NAME, "layers.%d.attention.wv.weight", i);
layer.w2 = llama_get_layer_w2(ilayer); layer.wv = llama_get_model_tensor(input, name);
layer.w3 = llama_get_layer_w3(ilayer);
snprintf(name, GGML_MAX_NAME, "layers.%d.attention.wo.weight", i);
layer.wo = llama_get_model_tensor(input, name);
snprintf(name, GGML_MAX_NAME, "layers.%d.ffn_norm.weight", i);
layer.ffn_norm = llama_get_model_tensor(input, name);
snprintf(name, GGML_MAX_NAME, "layers.%d.feed_forward.w1.weight", i);
layer.w1 = llama_get_model_tensor(input, name);
snprintf(name, GGML_MAX_NAME, "layers.%d.feed_forward.w2.weight", i);
layer.w2 = llama_get_model_tensor(input, name);
snprintf(name, GGML_MAX_NAME, "layers.%d.feed_forward.w3.weight", i);
layer.w3 = llama_get_model_tensor(input, name);
} }
} }

View file

@ -4213,62 +4213,8 @@ int llama_get_vocab(
return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity); return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity);
} }
struct llama_layer * llama_get_layer_from_model( struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
struct llama_model * model, return ggml_get_tensor(model->ctx, name);
int layer_idx) {
if (layer_idx < 0 || layer_idx >= model->hparams.n_layer) {
return NULL;
} else {
return &model->layers[layer_idx];
}
}
struct ggml_tensor * llama_get_model_tok_embeddings(struct llama_model * model) {
return model->tok_embeddings;
}
struct ggml_tensor * llama_get_model_norm(struct llama_model * model) {
return model->norm;
}
struct ggml_tensor * llama_get_model_output(struct llama_model * model) {
return model->output;
}
struct ggml_tensor * llama_get_layer_attention_norm(struct llama_layer * layer) {
return layer->attention_norm;
}
struct ggml_tensor * llama_get_layer_wq(struct llama_layer * layer) {
return layer->wq;
}
struct ggml_tensor * llama_get_layer_wk(struct llama_layer * layer) {
return layer->wk;
}
struct ggml_tensor * llama_get_layer_wv(struct llama_layer * layer) {
return layer->wv;
}
struct ggml_tensor * llama_get_layer_wo(struct llama_layer * layer) {
return layer->wo;
}
struct ggml_tensor * llama_get_layer_ffn_norm(struct llama_layer * layer) {
return layer->ffn_norm;
}
struct ggml_tensor * llama_get_layer_w1(struct llama_layer * layer) {
return layer->w1;
}
struct ggml_tensor * llama_get_layer_w2(struct llama_layer * layer) {
return layer->w2;
}
struct ggml_tensor * llama_get_layer_w3(struct llama_layer * layer) {
return layer->w3;
} }
float * llama_get_logits(struct llama_context * ctx) { float * llama_get_logits(struct llama_context * ctx) {

21
llama.h
View file

@ -69,7 +69,6 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_layer;
typedef int llama_token; typedef int llama_token;
@ -357,24 +356,8 @@ extern "C" {
float * scores, float * scores,
int capacity); int capacity);
// Get a llama layer // Get a llama model tensor
LLAMA_API struct llama_layer * llama_get_layer_from_model( LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
struct llama_model * model,
int layer);
LLAMA_API struct ggml_tensor * llama_get_model_tok_embeddings(struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_model_norm (struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_model_output (struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_layer_attention_norm(struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wq (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wk (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wv (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wo (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_ffn_norm (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w1 (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w2 (struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w3 (struct llama_layer * layer);
// Token logits obtained from the last call to llama_eval() // Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row // The logits for the last token are stored in the last row