start doing the instructions but not finished. This probably doesnt compile
This commit is contained in:
parent
84ab887349
commit
859e70899a
3 changed files with 20 additions and 2 deletions
16
llama.cpp
16
llama.cpp
|
@ -101,6 +101,8 @@ struct llama_context {
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
// input embedding (1-dimensional array: [n_embd])
|
||||||
|
std::vector<float> embedding;
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.f16_kv =*/ false,
|
/*.f16_kv =*/ false,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
|
/*.embedding =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -127,7 +130,8 @@ static bool llama_model_load(
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
int n_parts,
|
int n_parts,
|
||||||
ggml_type memory_type,
|
ggml_type memory_type,
|
||||||
bool vocab_only) {
|
bool vocab_only,
|
||||||
|
bool embedding) {
|
||||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||||
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
const int64_t t_start_us = ggml_time_us();
|
||||||
|
@ -594,6 +598,10 @@ static bool llama_model_load(
|
||||||
|
|
||||||
lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
||||||
|
|
||||||
|
if (embedding){
|
||||||
|
lctx.embedding.reserve(lctx.model.hparams.n_embd);
|
||||||
|
}
|
||||||
|
|
||||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -1433,7 +1441,7 @@ struct llama_context * llama_init_from_file(
|
||||||
|
|
||||||
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
|
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
|
||||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||||
delete ctx;
|
delete ctx;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1508,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
|
||||||
return ctx->logits.data();
|
return ctx->logits.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
|
return ctx->embedding.data();
|
||||||
|
}
|
||||||
|
|
||||||
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
||||||
if (token >= llama_n_vocab(ctx)) {
|
if (token >= llama_n_vocab(ctx)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
5
llama.h
5
llama.h
|
@ -53,6 +53,7 @@ extern "C" {
|
||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
|
bool embedding; // embedding mode only
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||||
|
@ -109,6 +110,10 @@ extern "C" {
|
||||||
// Cols: n_vocab
|
// Cols: n_vocab
|
||||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||||
|
|
||||||
|
// Get the embeddings for the input
|
||||||
|
// shape: [n_embd] (1-dimensional)
|
||||||
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
|
||||||
|
|
||||||
// Token Id -> String. Uses the vocabulary in the provided context
|
// Token Id -> String. Uses the vocabulary in the provided context
|
||||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||||
|
|
||||||
|
|
1
main.cpp
1
main.cpp
|
@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
|
||||||
lparams.seed = params.seed;
|
lparams.seed = params.seed;
|
||||||
lparams.f16_kv = params.memory_f16;
|
lparams.f16_kv = params.memory_f16;
|
||||||
lparams.logits_all = params.perplexity;
|
lparams.logits_all = params.perplexity;
|
||||||
|
lparams.embedding = params.embedding;
|
||||||
|
|
||||||
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue