llama : only save and restore used logits
For batch sizes of 512 this reduces the saved state in the best case by around 62 MB, which can be significant if saving state on every message to allow regenerating messages.
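For reference, a back-of-the-envelope check of that figure (the vocabulary size is an assumption, not stated in the commit: default LLaMA vocabulary of 32,000 tokens, logits kept for all 512 batch positions):

    // hypothetical sizing check, not part of the commit itself
    constexpr size_t n_batch = 512;
    constexpr size_t n_vocab = 32000;                          // default LLaMA vocab size
    constexpr size_t s_full  = n_batch * n_vocab * sizeof(float);
    static_assert(s_full == 65536000, "~62.5 MB");             // previously always written

With this change only the logits actually populated are serialized, so in the best case (a single row of logits in use) roughly n_batch - 1 rows of zero padding are no longer written.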
parent b9c60dec98
commit 5ee581473e

1 changed file with 5 additions and 17 deletions
llama.cpp | 22 +++++-----------------
@@ -10145,8 +10145,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
     const size_t s_rng_size        = sizeof(size_t);
     const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_logits_capacity = sizeof(size_t);
     const size_t s_logits_size     = sizeof(size_t);
+    // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
@@ -10157,7 +10157,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_total = (
         + s_rng_size
         + s_rng
-        + s_logits_capacity
         + s_logits_size
         + s_logits
         + s_embedding_size
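Since s_logits is still computed from capacity() while the writer below only emits size() floats, llama_get_state_size() now returns an upper bound rather than the exact serialized size. A minimal caller sketch under that assumption, using the existing public API (error handling omitted):

    // size the buffer for the worst case, then trim to what was actually written
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    const size_t n_written = llama_copy_state_data(ctx, buf.data());
    buf.resize(n_written);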
@@ -10241,22 +10240,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy logits
     {
-        const size_t logits_cap  = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        data_ctx->write(&logits_cap,  sizeof(logits_cap));
         data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
             data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
-
-        // If there is a gap between the size and the capacity, write padding
-        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
-        if (padding_size > 0) {
-            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            data_ctx->write(padding.data(), padding_size);
-        }
     }
 
     // copy embeddings
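The resulting change to the serialized logits section, sketched from the two versions of the writer (field names informal):

    before: [logits_cap][logits_size][logits_size floats][zero padding up to logits_cap floats]
    after:  [logits_size][logits_size floats]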
@@ -10380,20 +10370,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
     // set logits
     {
-        size_t logits_cap;
         size_t logits_size;
 
-        memcpy(&logits_cap,  inp, sizeof(logits_cap));  inp += sizeof(logits_cap);
         memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
-        GGML_ASSERT(ctx->logits.capacity() == logits_cap);
+        GGML_ASSERT(ctx->logits.capacity() >= logits_size);
 
         if (logits_size) {
             ctx->logits.resize(logits_size);
-            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
-        }
 
-        inp += logits_cap * sizeof(float);
+            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
+            inp += logits_size * sizeof(float);
+        }
     }
 
     // set embeddings
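One consequence worth noting: this is a breaking change to the state format. A state saved before this commit starts with logits_cap, which the new reader would consume as logits_size and then misparse the remaining bytes; nothing in these functions versions the format to catch that. The restore side, reusing buf from the sketch above (hypothetical usage, error handling omitted):

    // restore a previously captured state into the same context/model
    llama_set_state_data(ctx, buf.data());   // returns the number of bytes read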