start doing the instructions but not finished. This probably doesnt compile

This commit is contained in:
strikingLoo 2023-03-22 17:52:46 -07:00
parent 84ab887349
commit 859e70899a
3 changed files with 20 additions and 2 deletions

View file

@ -101,6 +101,8 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
bool logits_all = false; bool logits_all = false;
}; };
@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false, /*.f16_kv =*/ false,
/*.logits_all =*/ false, /*.logits_all =*/ false,
/*.vocab_only =*/ false, /*.vocab_only =*/ false,
/*.embedding =*/ false,
}; };
return result; return result;
@ -127,7 +130,8 @@ static bool llama_model_load(
int n_ctx, int n_ctx,
int n_parts, int n_parts,
ggml_type memory_type, ggml_type memory_type,
bool vocab_only) { bool vocab_only,
bool embedding) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
@ -594,6 +598,10 @@ static bool llama_model_load(
lctx.logits.reserve(lctx.model.hparams.n_ctx); lctx.logits.reserve(lctx.model.hparams.n_ctx);
if (embedding){
lctx.embedding.reserve(lctx.model.hparams.n_embd);
}
lctx.t_load_us = ggml_time_us() - t_start_us; lctx.t_load_us = ggml_time_us() - t_start_us;
return true; return true;
@ -1433,7 +1441,7 @@ struct llama_context * llama_init_from_file(
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) { if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
fprintf(stderr, "%s: failed to load model\n", __func__); fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx; delete ctx;
return nullptr; return nullptr;
@ -1508,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data(); return ctx->logits.data();
} }
float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) { const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) { if (token >= llama_n_vocab(ctx)) {
return nullptr; return nullptr;

View file

@ -53,6 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights bool vocab_only; // only load the vocabulary, no weights
bool embedding; // embedding mode only
}; };
LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_context_params llama_context_default_params();
@ -109,6 +110,10 @@ extern "C" {
// Cols: n_vocab // Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx); LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

View file

@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed; lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16; lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity; lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams); ctx = llama_init_from_file(params.model.c_str(), lparams);