Start implementing the instructions; not finished yet. This probably doesn't compile.
This commit is contained in:
parent 84ab887349
commit 859e70899a
3 changed files with 20 additions and 2 deletions
llama.cpp (16 changes: 14 additions, 2 deletions)
@@ -101,6 +101,8 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
     bool logits_all = false;
 };
@@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv     =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
+        /*.embedding  =*/ false,
     };

     return result;
@@ -127,7 +130,8 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        bool embedding) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

     const int64_t t_start_us = ggml_time_us();
@@ -594,6 +598,10 @@ static bool llama_model_load(

     lctx.logits.reserve(lctx.model.hparams.n_ctx);

+    if (embedding){
+        lctx.embedding.reserve(lctx.model.hparams.n_embd);
+    }
+
     lctx.t_load_us = ggml_time_us() - t_start_us;

     return true;
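One thing worth flagging in the hunk above: std::vector::reserve() only raises capacity, so lctx.embedding keeps size 0 and its data() stays unusable until some later step resizes and fills the buffer. A minimal sketch of that missing fill step, assuming the computed embedding ends up in a ggml tensor (embeddings_tensor is a hypothetical name, not part of this diff):

    // Hypothetical sketch of the fill step this commit does not implement yet.
    // It would run inside llama_eval(), after the graph has been computed.
    if (embedding) {  // assumes the flag is stashed somewhere reachable from eval
        const int n_embd = lctx.model.hparams.n_embd;
        lctx.embedding.resize(n_embd);  // resize (not reserve) so data() is readable
        memcpy(lctx.embedding.data(),
               (float *) ggml_get_data(embeddings_tensor),  // hypothetical tensor name
               sizeof(float) * n_embd);
    }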
@@ -1433,7 +1441,7 @@ struct llama_context * llama_init_from_file(

     ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
@@ -1508,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }

+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
 const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
     if (token >= llama_n_vocab(ctx)) {
         return nullptr;
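The new accessor mirrors llama_get_logits(), but because the buffer is only reserved (see the earlier hunk), ctx->embedding is empty at this point and data() on an empty vector may legally be nullptr. A defensive caller sketch (the nullptr check is an assumption about intended usage, not something this diff guarantees):

    // Treat the pointer as valid only once an eval has produced embedding data.
    const float * emb = llama_get_embeddings(ctx);
    if (emb == NULL) {
        fprintf(stderr, "no embeddings available for this context\n");
    }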
llama.h (5 additions)
@@ -53,6 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
+        bool embedding;  // embedding mode only
     };

     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -109,6 +110,10 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);

+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
+
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
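The declaration added above lacks a trailing semicolon, which is very likely the compile failure the commit message warns about. The corrected line would read:

    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);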
main.cpp (1 addition)
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
     lparams.seed       = params.seed;
     lparams.f16_kv     = params.memory_f16;
     lparams.logits_all = params.perplexity;
+    lparams.embedding  = params.embedding;

     ctx = llama_init_from_file(params.model.c_str(), lparams);
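For context, a hedged end-to-end sketch of how a caller could exercise the new flag, using only functions already declared in llama.h at this point (the model path, token id, and thread count are placeholders):

    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true;

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx) {
        llama_token tokens[] = { 1 };  // placeholder: a real prompt would be tokenized first
        if (llama_eval(ctx, tokens, 1, 0, 4) == 0) {  // n_past = 0, 4 threads
            const float * emb    = llama_get_embeddings(ctx);
            const int     n_embd = llama_n_embd(ctx);
            // emb[0 .. n_embd-1] would hold the input embedding once llama_eval
            // actually fills the buffer, which this commit does not do yet
        }
        llama_free(ctx);
    }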