Merge branch 'master' into example-readme
This commit is contained in:
commit
f30e41e0ef
3 changed files with 182 additions and 157 deletions
|
@ -187,5 +187,5 @@ These options provide extra functionality and customization when running the LLa
|
|||
- `-h, --help`: Display a help message showing all available options and their default values. This is particularly useful for checking the latest options and default values, as they can change frequently, and the information in this document may become outdated.
|
||||
- `--verbose-prompt`: Print the prompt before generating text.
|
||||
- `--mtest`: Test the model's functionality by running a series of tests to ensure it's working properly.
|
||||
- `--lora FNAME`: Apply a LoRA (Layer-wise Relevance Approximation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
||||
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
||||
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
|
||||
|
|
323
llama.cpp
323
llama.cpp
|
@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
|
|||
}
|
||||
}
|
||||
|
||||
// Returns the KV cache that will contain the context for the
|
||||
// ongoing prediction with the model.
|
||||
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
|
||||
return ctx->model.kv_self.buf.addr;
|
||||
}
|
||||
|
||||
// Returns the size of the KV cache
|
||||
size_t llama_get_kv_cache_size(struct llama_context * ctx) {
|
||||
return ctx->model.kv_self.buf.size;
|
||||
}
|
||||
|
||||
int llama_get_kv_cache_token_count(struct llama_context * ctx) {
|
||||
return ctx->model.kv_self.n;
|
||||
}
|
||||
|
||||
// Sets the KV cache containing the current context for the model
|
||||
void llama_set_kv_cache(
|
||||
struct llama_context * ctx,
|
||||
const uint8_t * kv_cache,
|
||||
size_t n_size,
|
||||
int n_token_count) {
|
||||
// Make sure we have the same kv cache setup
|
||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
|
||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
||||
memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
|
||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
||||
ctx->model.kv_self.v->data = v_data;
|
||||
ctx->model.kv_self.n = n_token_count;
|
||||
#define LLAMA_MAX_RNG_STATE 64*1024
|
||||
|
||||
// Returns the size of the state
|
||||
size_t llama_get_state_size(struct llama_context * ctx) {
|
||||
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
||||
// for reference, std::mt19937(1337) serializes to 6701 bytes.
|
||||
const size_t s_rng_size = sizeof(size_t);
|
||||
const size_t s_rng = LLAMA_MAX_RNG_STATE;
|
||||
const size_t s_logits_capacity = sizeof(size_t);
|
||||
const size_t s_logits_size = sizeof(size_t);
|
||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||
const size_t s_embedding_size = sizeof(size_t);
|
||||
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
||||
const size_t s_kv_size = sizeof(size_t);
|
||||
const size_t s_kv_ntok = sizeof(int);
|
||||
const size_t s_kv = ctx->model.kv_self.buf.size;
|
||||
|
||||
const size_t s_total = (
|
||||
+ s_rng_size
|
||||
+ s_rng
|
||||
+ s_logits_capacity
|
||||
+ s_logits_size
|
||||
+ s_logits
|
||||
+ s_embedding_size
|
||||
+ s_embedding
|
||||
+ s_kv_size
|
||||
+ s_kv_ntok
|
||||
+ s_kv
|
||||
);
|
||||
|
||||
return s_total;
|
||||
}
|
||||
|
||||
// Copies the state to the specified destination address
|
||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
|
||||
uint8_t * out = dest;
|
||||
|
||||
// copy rng
|
||||
{
|
||||
std::stringstream rng_ss;
|
||||
rng_ss << ctx->rng;
|
||||
|
||||
const size_t rng_size = rng_ss.str().size();
|
||||
char rng_buf[LLAMA_MAX_RNG_STATE];
|
||||
|
||||
memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
|
||||
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
|
||||
|
||||
memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
|
||||
memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
|
||||
}
|
||||
|
||||
// copy logits
|
||||
{
|
||||
const size_t logits_cap = ctx->logits.capacity();
|
||||
const size_t logits_size = ctx->logits.size();
|
||||
|
||||
memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
|
||||
memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
|
||||
|
||||
if (logits_size) {
|
||||
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
|
||||
}
|
||||
|
||||
out += logits_cap * sizeof(float);
|
||||
}
|
||||
|
||||
// copy embeddings
|
||||
{
|
||||
const size_t embedding_size = ctx->embedding.size();
|
||||
|
||||
memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
|
||||
|
||||
if (embedding_size) {
|
||||
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
|
||||
out += embedding_size * sizeof(float);
|
||||
}
|
||||
}
|
||||
|
||||
// copy kv cache
|
||||
{
|
||||
const size_t kv_size = ctx->model.kv_self.buf.size;
|
||||
const int kv_ntok = llama_get_kv_cache_token_count(ctx);
|
||||
|
||||
memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
|
||||
memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
|
||||
|
||||
if (kv_size) {
|
||||
memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t written = out - dest;
|
||||
const size_t expected = llama_get_state_size(ctx);
|
||||
|
||||
LLAMA_ASSERT(written == expected);
|
||||
|
||||
return written;
|
||||
}
|
||||
|
||||
// Sets the state reading from the specified source address
|
||||
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||
const uint8_t * in = src;
|
||||
|
||||
// set rng
|
||||
{
|
||||
size_t rng_size;
|
||||
char rng_buf[LLAMA_MAX_RNG_STATE];
|
||||
|
||||
memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
|
||||
memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
|
||||
|
||||
std::stringstream rng_ss;
|
||||
rng_ss.str(std::string(&rng_buf[0], rng_size));
|
||||
rng_ss >> ctx->rng;
|
||||
|
||||
LLAMA_ASSERT(rng_ss.fail() == false);
|
||||
}
|
||||
|
||||
// set logits
|
||||
{
|
||||
size_t logits_cap;
|
||||
size_t logits_size;
|
||||
|
||||
memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap);
|
||||
memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
|
||||
|
||||
LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
|
||||
|
||||
if (logits_size) {
|
||||
ctx->logits.resize(logits_size);
|
||||
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
|
||||
}
|
||||
|
||||
in += logits_cap * sizeof(float);
|
||||
}
|
||||
|
||||
// set embeddings
|
||||
{
|
||||
size_t embedding_size;
|
||||
|
||||
memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
|
||||
|
||||
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
|
||||
|
||||
if (embedding_size) {
|
||||
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
|
||||
in += embedding_size * sizeof(float);
|
||||
}
|
||||
}
|
||||
|
||||
// set kv cache
|
||||
{
|
||||
size_t kv_size;
|
||||
int kv_ntok;
|
||||
|
||||
memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
|
||||
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
|
||||
|
||||
if (kv_size) {
|
||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
|
||||
|
||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
||||
|
||||
memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
|
||||
|
||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
||||
ctx->model.kv_self.v->data = v_data;
|
||||
|
||||
}
|
||||
|
||||
ctx->model.kv_self.n = kv_ntok;
|
||||
}
|
||||
|
||||
const size_t nread = in - src;
|
||||
const size_t expected = llama_get_state_size(ctx);
|
||||
|
||||
LLAMA_ASSERT(nread == expected);
|
||||
|
||||
return nread;
|
||||
}
|
||||
|
||||
int llama_eval(
|
||||
|
@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
|
|||
return ctx->model.tensors_by_name;
|
||||
}
|
||||
|
||||
// Returns the size of the state
|
||||
size_t llama_get_state_size(struct llama_context * ctx) {
|
||||
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
||||
// for reference, std::mt19937(1337) serializes to 6701 bytes.
|
||||
const size_t s_rng_size = sizeof(size_t);
|
||||
const size_t s_rng = 64*1024;
|
||||
const size_t s_logits_capacity = sizeof(size_t);
|
||||
const size_t s_logits_size = sizeof(size_t);
|
||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||
const size_t s_embedding_size = sizeof(size_t);
|
||||
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
||||
const size_t s_kv_size = sizeof(size_t);
|
||||
const size_t s_kv_ntok = sizeof(int);
|
||||
const size_t s_kv = llama_get_kv_cache_size(ctx);
|
||||
const size_t s_total = (
|
||||
+ s_rng_size
|
||||
+ s_rng
|
||||
+ s_logits_capacity
|
||||
+ s_logits_size
|
||||
+ s_logits
|
||||
+ s_embedding_size
|
||||
+ s_embedding
|
||||
+ s_kv_size
|
||||
+ s_kv_ntok
|
||||
+ s_kv
|
||||
);
|
||||
return s_total;
|
||||
}
|
||||
|
||||
// Copies the state to the specified destination address
|
||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
|
||||
std::stringstream rng_ss;
|
||||
rng_ss << ctx->rng;
|
||||
const size_t rng_size = rng_ss.str().size();
|
||||
char rng_buf[64*1024];
|
||||
memset(&rng_buf[0], 0, 64*1024);
|
||||
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
|
||||
const size_t logits_capacity = ctx->logits.capacity();
|
||||
const size_t logits_size = ctx->logits.size();
|
||||
const size_t embedding_size = ctx->embedding.size();
|
||||
const size_t kv_size = llama_get_kv_cache_size(ctx);
|
||||
const int kv_ntok = llama_get_kv_cache_token_count(ctx);
|
||||
|
||||
uint8_t * out = dest;
|
||||
memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
|
||||
memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
|
||||
memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
|
||||
memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
|
||||
if (logits_size) {
|
||||
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
|
||||
}
|
||||
out += logits_capacity * sizeof(float);
|
||||
memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
|
||||
if (embedding_size) {
|
||||
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
|
||||
}
|
||||
memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
|
||||
memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
|
||||
if (kv_size) {
|
||||
memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
|
||||
}
|
||||
const size_t written = out - dest;
|
||||
const size_t expected = llama_get_state_size(ctx);
|
||||
LLAMA_ASSERT(written == expected);
|
||||
return written;
|
||||
}
|
||||
|
||||
// Sets the state reading from the specified source address
|
||||
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||
size_t rng_size;
|
||||
char rng_buf[64*1024];
|
||||
std::stringstream rng_ss;
|
||||
|
||||
const uint8_t * in = src;
|
||||
memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
|
||||
memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
|
||||
rng_ss.str(std::string(&rng_buf[0], rng_size));
|
||||
rng_ss >> ctx->rng;
|
||||
LLAMA_ASSERT(rng_ss.fail() == false);
|
||||
|
||||
size_t logits_capacity;
|
||||
size_t logits_size;
|
||||
size_t embedding_size;
|
||||
size_t kv_size;
|
||||
int kv_ntok;
|
||||
|
||||
memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
|
||||
memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
|
||||
LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
|
||||
if (logits_size) {
|
||||
ctx->logits.resize(logits_size);
|
||||
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
|
||||
}
|
||||
in += logits_capacity * sizeof(float);
|
||||
memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
|
||||
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
|
||||
if (embedding_size) {
|
||||
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
|
||||
in += embedding_size * sizeof(float);
|
||||
}
|
||||
memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
|
||||
memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
|
||||
if (kv_size) {
|
||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
|
||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
||||
memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
|
||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
||||
ctx->model.kv_self.v->data = v_data;
|
||||
in += kv_size;
|
||||
}
|
||||
ctx->model.kv_self.n = kv_ntok;
|
||||
const size_t nread = in - src;
|
||||
const size_t expected = llama_get_state_size(ctx);
|
||||
LLAMA_ASSERT(nread == expected);
|
||||
return nread;
|
||||
}
|
||||
|
|
14
llama.h
14
llama.h
|
@ -112,23 +112,9 @@ extern "C" {
|
|||
const char * path_base_model,
|
||||
int n_threads);
|
||||
|
||||
// Returns the KV cache that will contain the context for the
|
||||
// ongoing prediction with the model.
|
||||
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
|
||||
|
||||
// Returns the size of the KV cache
|
||||
LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
|
||||
|
||||
// Sets the KV cache containing the current context for the model
|
||||
LLAMA_API void llama_set_kv_cache(
|
||||
struct llama_context * ctx,
|
||||
const uint8_t * kv_cache,
|
||||
size_t n_size,
|
||||
int n_token_count);
|
||||
|
||||
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
||||
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue