Merge branch 'master' into mlock
This commit is contained in:
commit
a65f23342d
6 changed files with 109 additions and 35 deletions
|
@ -7,8 +7,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
|
|||
|
||||
**Hot topics:**
|
||||
|
||||
- [Roadmap (short-term)](https://github.com/ggerganov/llama.cpp/discussions/457)
|
||||
- New C-style API is now available: https://github.com/ggerganov/llama.cpp/pull/370
|
||||
- [Added Alpaca support](https://github.com/ggerganov/llama.cpp#instruction-mode-with-alpaca)
|
||||
- Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64
|
||||
- Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105
|
||||
|
||||
|
|
92
llama.cpp
92
llama.cpp
|
@ -103,8 +103,8 @@ struct llama_context {
|
|||
std::vector<float> logits;
|
||||
bool logits_all = false;
|
||||
|
||||
// work buffer for transformer evaluation
|
||||
std::vector<uint8_t> buf_eval;
|
||||
// input embedding (1-dimensional array: [n_embd])
|
||||
std::vector<float> embedding;
|
||||
};
|
||||
|
||||
struct llama_context_params llama_context_default_params() {
|
||||
|
@ -116,6 +116,7 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.logits_all =*/ false,
|
||||
/*.vocab_only =*/ false,
|
||||
/*.use_mlock =*/ false,
|
||||
/*.embedding =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
@ -131,8 +132,7 @@ static bool llama_model_load(
|
|||
int n_ctx,
|
||||
int n_parts,
|
||||
ggml_type memory_type,
|
||||
bool vocab_only,
|
||||
bool use_mlock) {
|
||||
bool vocab_only) {
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
@ -597,17 +597,6 @@ static bool llama_model_load(
|
|||
fin.close();
|
||||
}
|
||||
|
||||
if (use_mlock) {
|
||||
char *err;
|
||||
if (!ggml_mlock(ctx, &err)) {
|
||||
fprintf(stderr, "%s\n", err);
|
||||
free(err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
||||
|
||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||
|
||||
return true;
|
||||
|
@ -641,19 +630,27 @@ static bool llama_eval_internal(
|
|||
const int n_rot = hparams.n_embd/hparams.n_head;
|
||||
|
||||
auto & mem_per_token = lctx.mem_per_token;
|
||||
auto & buf_eval = lctx.buf_eval;
|
||||
|
||||
if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
|
||||
const size_t buf_size_new = 1.618*buf_eval.size();
|
||||
// TODO: fix this hardcoded size
|
||||
static size_t buf_size = 512u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
|
||||
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
|
||||
const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
|
||||
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
buf_eval.resize(buf_size_new);
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_eval.size(),
|
||||
/*.mem_buffer =*/ buf_eval.data(),
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ buf,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
@ -797,6 +794,9 @@ static bool llama_eval_internal(
|
|||
inpL = cur;
|
||||
}
|
||||
|
||||
// used at the end to optionally extract the embeddings
|
||||
struct ggml_tensor * embeddings = NULL;
|
||||
|
||||
// norm
|
||||
{
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
|
@ -805,6 +805,8 @@ static bool llama_eval_internal(
|
|||
inpL = ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.norm, inpL),
|
||||
inpL);
|
||||
|
||||
embeddings = inpL;
|
||||
}
|
||||
|
||||
// lm_head
|
||||
|
@ -827,6 +829,8 @@ static bool llama_eval_internal(
|
|||
//embd_w.resize(n_vocab*N);
|
||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||
|
||||
// extract logits
|
||||
{
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
if (lctx.logits_all) {
|
||||
|
@ -837,12 +841,20 @@ static bool llama_eval_internal(
|
|||
logits_out.resize(n_vocab);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
}
|
||||
|
||||
if (N == 1) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
|
||||
}
|
||||
|
||||
//fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
|
||||
// extract embeddings
|
||||
if (lctx.embedding.size()) {
|
||||
auto & embedding_out = lctx.embedding;
|
||||
|
||||
embedding_out.resize(n_embd);
|
||||
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
|
||||
}
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
|
@ -1424,7 +1436,29 @@ struct llama_context * llama_init_from_file(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ctx->buf_eval.resize(512u*1024u*1024u);
|
||||
if (params.use_mlock) {
|
||||
char *err;
|
||||
if (!ggml_mlock(ctx->model.ctx, &err)) {
|
||||
fprintf(stderr, "%s\n", err);
|
||||
free(err);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// reserve memory for context buffers
|
||||
{
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
if (params.logits_all) {
|
||||
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
||||
} else {
|
||||
ctx->logits.reserve(hparams.n_ctx);
|
||||
}
|
||||
|
||||
if (params.embedding){
|
||||
ctx->embedding.reserve(hparams.n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
@ -1494,6 +1528,10 @@ float * llama_get_logits(struct llama_context * ctx) {
|
|||
return ctx->logits.data();
|
||||
}
|
||||
|
||||
// Returns a pointer to the first element of the context's embedding vector
// (ctx->embedding), i.e. whatever the most recent evaluation copied into it.
// NOTE(review): the eval path shown in this diff copies embeddings only when
// ctx->embedding.size() is non-zero, while the init path shown here only
// reserve()s the vector; confirm it is actually sized somewhere, otherwise
// nothing is ever written to the buffer this pointer refers to.
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||
return ctx->embedding.data();
|
||||
}
|
||||
|
||||
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
||||
if (token >= llama_n_vocab(ctx)) {
|
||||
return nullptr;
|
||||
|
|
5
llama.h
5
llama.h
|
@ -54,6 +54,7 @@ extern "C" {
|
|||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
};
|
||||
|
||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||
|
@ -109,6 +110,10 @@ extern "C" {
|
|||
// Cols: n_vocab
|
||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the input
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||
|
||||
|
|
23
main.cpp
23
main.cpp
|
@ -200,6 +200,7 @@ int main(int argc, char ** argv) {
|
|||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.logits_all = params.perplexity;
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
lparams.embedding = params.embedding;
|
||||
|
||||
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||
|
||||
|
@ -293,6 +294,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
|
||||
int last_n_size = params.repeat_last_n;
|
||||
std::vector<llama_token> last_n_tokens(last_n_size);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
|
@ -325,6 +327,27 @@ int main(int argc, char ** argv) {
|
|||
// the first thing we will do is to output the prompt, so set color accordingly
|
||||
set_console_state(CONSOLE_STATE_PROMPT);
|
||||
|
||||
if (params.embedding){
|
||||
embd = embd_inp;
|
||||
|
||||
if (embd.size() > 0) {
|
||||
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
const auto embeddings = llama_get_embeddings(ctx);
|
||||
|
||||
// TODO: print / use the embeddings
|
||||
|
||||
if (params.use_color) {
|
||||
printf(ANSI_COLOR_RESET);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
while (remaining_tokens > 0 || params.interactive) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
|
|
|
@ -119,6 +119,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.model = argv[i];
|
||||
} else if (arg == "-i" || arg == "--interactive") {
|
||||
params.interactive = true;
|
||||
} else if (arg == "--embedding") {
|
||||
params.embedding = true;
|
||||
} else if (arg == "--interactive-start") {
|
||||
params.interactive = true;
|
||||
} else if (arg == "--interactive-first") {
|
||||
params.interactive_start = true;
|
||||
} else if (arg == "-ins" || arg == "--instruct") {
|
||||
|
|
4
utils.h
4
utils.h
|
@ -32,13 +32,17 @@ struct gpt_params {
|
|||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
|
||||
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
|
||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool interactive = false; // interactive mode
|
||||
|
||||
bool embedding = false; // get only sentence embedding
|
||||
bool interactive_start = false; // wait for user input immediately
|
||||
|
||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool ignore_eos = false; // do not stop generating after eos
|
||||
bool perplexity = false; // compute perplexity over the prompt
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue