add API functions to access llama model tensors

This commit is contained in:
xaedes 2023-08-06 17:28:22 +02:00
parent 3b5515bbe0
commit 316b0707f4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 98 additions and 0 deletions

View file

@ -4147,6 +4147,10 @@ int llama_n_embd_from_model(const struct llama_model * model) {
return model->hparams.n_embd; return model->hparams.n_embd;
} }
int llama_n_layer_from_model(const struct llama_model * model) {
return model->hparams.n_layer;
}
int llama_n_vocab(const struct llama_context * ctx) { int llama_n_vocab(const struct llama_context * ctx) {
return ctx->model.vocab.id_to_token.size(); return ctx->model.vocab.id_to_token.size();
} }
@ -4159,6 +4163,10 @@ int llama_n_embd(const struct llama_context * ctx) {
return ctx->model.hparams.n_embd; return ctx->model.hparams.n_embd;
} }
int llama_n_layer(const struct llama_context * ctx) {
return ctx->model.hparams.n_layer;
}
int llama_get_vocab_from_model( int llama_get_vocab_from_model(
const struct llama_model * model, const struct llama_model * model,
const char * * strings, const char * * strings,
@ -4180,6 +4188,70 @@ int llama_get_vocab(
return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity); return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity);
} }
struct llama_layer * llama_get_layer_from_model(
const struct llama_model * model,
int layer_idx) {
if (layer_idx < 0 || layer_idx >= model->hparams.n_layer) {
return NULL;
} else {
return &model->layers[layer_idx];
}
}
struct llama_layer * llama_get_layer(
const struct llama_context * ctx,
int layer_idx) {
return llama_get_layer_from_model(&ctx->model, layer_idx);
}
struct ggml_tensor * llama_get_model_tok_embeddings(const struct llama_model * model) {
return model->tok_embeddings;
}
struct ggml_tensor * llama_get_model_norm(const struct llama_model * model) {
return model->norm;
}
struct ggml_tensor * llama_get_model_output(const struct llama_model * model) {
return model->output;
}
struct ggml_tensor * llama_get_layer_attention_norm(const struct llama_layer * layer) {
return layer->attention_norm;
}
struct ggml_tensor * llama_get_layer_wq(const struct llama_layer * layer) {
return layer->wq;
}
struct ggml_tensor * llama_get_layer_wk(const struct llama_layer * layer) {
return layer->wk;
}
struct ggml_tensor * llama_get_layer_wv(const struct llama_layer * layer) {
return layer->wv;
}
struct ggml_tensor * llama_get_layer_wo(const struct llama_layer * layer) {
return layer->wo;
}
struct ggml_tensor * llama_get_layer_ffn_norm(const struct llama_layer * layer) {
return layer->ffn_norm;
}
struct ggml_tensor * llama_get_layer_w1(const struct llama_layer * layer) {
return layer->w1;
}
struct ggml_tensor * llama_get_layer_w2(const struct llama_layer * layer) {
return layer->w2;
}
struct ggml_tensor * llama_get_layer_w3(const struct llama_layer * layer) {
return layer->w3;
}
float * llama_get_logits(struct llama_context * ctx) { float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data(); return ctx->logits.data();
} }

26
llama.h
View file

@ -69,6 +69,7 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_layer;
typedef int llama_token; typedef int llama_token;
@ -329,10 +330,12 @@ extern "C" {
LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API int llama_n_layer(const struct llama_context * ctx);
LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model); LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model); LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model);
LLAMA_API int llama_n_embd_from_model (const struct llama_model * model); LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
LLAMA_API int llama_n_layer_from_model(const struct llama_model * model);
// Get the vocabulary as output parameters. // Get the vocabulary as output parameters.
// Returns number of results. // Returns number of results.
@ -348,6 +351,29 @@ extern "C" {
float * scores, float * scores,
int capacity); int capacity);
// Get a llama layer
LLAMA_API struct llama_layer * llama_get_layer(
const struct llama_context * ctx,
int layer);
LLAMA_API struct llama_layer * llama_get_layer_from_model(
const struct llama_model * model,
int layer);
LLAMA_API struct ggml_tensor * llama_get_model_tok_embeddings(const struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_model_norm (const struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_model_output (const struct llama_model * model);
LLAMA_API struct ggml_tensor * llama_get_layer_attention_norm(const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wq (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wk (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wv (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_wo (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_ffn_norm (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w1 (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w2 (const struct llama_layer * layer);
LLAMA_API struct ggml_tensor * llama_get_layer_w3 (const struct llama_layer * layer);
// Token logits obtained from the last call to llama_eval() // Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row // The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token // Can be mutated in order to change the probabilities of the next token