generalised copying state data to file or buffer

l3utterfly 2023-08-03 17:02:29 +08:00
parent 6c798db041
commit 3cebd6e4b7
2 changed files with 39 additions and 113 deletions


@@ -149,6 +149,29 @@ struct llama_file {
    }
};
+// llama_data_context
+struct llama_data_context {
+    virtual void write(const void* src, size_t size) = 0;
+    virtual ~llama_data_context() = default;
+};
+struct llama_data_buffer_context : llama_data_context {
+    uint8_t* ptr;
+    llama_data_buffer_context(uint8_t* p) : ptr(p) {}
+    void write(const void* src, size_t size) override {
+        memcpy(ptr, src, size);
+        ptr += size;
+    }
+};
+struct llama_data_file_context : llama_data_context {
+    llama_file* file;
+    llama_data_file_context(llama_file* f) : file(f) {}
+    void write(const void* src, size_t size) override {
+        file->write_raw(src, size);
+    }
+};
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
    LPSTR buf;
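
The write() method above is the entire contract between the copy routine and its destination, so any sink that can consume a span of bytes can be dropped in. As a rough illustration (not part of this commit), a hypothetical byte-counting context could reuse the same copy path to measure the serialized state size without allocating a buffer:

struct llama_data_size_context : llama_data_context {
    size_t total = 0;
    void write(const void* src, size_t size) override {
        (void) src; // only the length matters here; the bytes are discarded
        total += size;
    }
};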

llama.cpp

@@ -3743,10 +3743,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
    return s_total;
}
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    uint8_t * out = dst;
+// copy state data into either a buffer or file depending on the passed in context
+void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) {
    // copy rng
    {
        std::stringstream rng_ss;
@@ -3758,8 +3756,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-        memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
-        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+        data_ctx->write(&rng_size, sizeof(rng_size));
+        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
    }
    // copy logits
@@ -3767,114 +3765,18 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
        const size_t logits_cap = ctx->logits.capacity();
        const size_t logits_size = ctx->logits.size();
-        memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
-        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+        data_ctx->write(&logits_cap, sizeof(logits_cap));
+        data_ctx->write(&logits_size, sizeof(logits_size));
        if (logits_size) {
-            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
-        }
-        out += logits_cap * sizeof(float);
-    }
-    // copy embeddings
-    {
-        const size_t embedding_size = ctx->embedding.size();
-        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
-        if (embedding_size) {
-            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
-            out += embedding_size * sizeof(float);
-        }
-    }
-    // copy kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-        const int n_layer = hparams.n_layer;
-        const int n_embd = hparams.n_embd_gqa();
-        const int n_ctx = hparams.n_ctx;
-        const size_t kv_size = kv_self.buf.size;
-        const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
-        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
-        if (kv_size) {
-            const size_t elt_size = ggml_element_size(kv_self.k);
-            ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
-            ggml_cgraph gf{};
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kout3d->data = out;
-            out += ggml_nbytes(kout3d);
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vout3d->data = out;
-            out += ggml_nbytes(vout3d);
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
-            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
-            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
-            ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
-            ggml_free(cpy_ctx);
-        }
-    }
-    const size_t written = out - dst;
-    const size_t max_size = llama_get_state_size(ctx);
-    LLAMA_ASSERT(written <= max_size);
-    return written;
-}
-// writes state data directly to file instead of copying it into a buffer
-void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) {
-    // copy rng
-    {
-        std::stringstream rng_ss;
-        rng_ss << ctx->rng;
-        const size_t rng_size = rng_ss.str().size();
-        char rng_buf[LLAMA_MAX_RNG_STATE];
-        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
-        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-        dst_file->write_raw(&rng_size, sizeof(rng_size));
-        dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE);
-    }
-    // copy logits
-    {
-        const size_t logits_cap = ctx->logits.capacity();
-        const size_t logits_size = ctx->logits.size();
-        dst_file->write_raw(&logits_cap, sizeof(logits_cap));
-        dst_file->write_raw(&logits_size, sizeof(logits_size));
-        if (logits_size) {
-            dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float));
        }
        // If there is a gap between the size and the capacity, write padding
        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
        if (padding_size > 0) {
            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            dst_file->write_raw(padding.data(), padding_size);
+            data_ctx->write(padding.data(), padding_size);
        }
    }
@@ -3882,10 +3784,10 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
    {
        const size_t embedding_size = ctx->embedding.size();
-        dst_file->write_raw(&embedding_size, sizeof(embedding_size));
+        data_ctx->write(&embedding_size, sizeof(embedding_size));
        if (embedding_size) {
-            dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float));
+            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
        }
    }
@@ -3900,8 +3802,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
        const size_t kv_size = kv_self.buf.size;
        const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-        dst_file->write_raw(&kv_size, sizeof(kv_size));
-        dst_file->write_raw(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
        if (kv_size) {
            const size_t elt_size = ggml_element_size(kv_self.k);
@@ -3933,8 +3835,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
            // our data is now in the kout3d_data and vout3d_data buffers
            // write them to file
-            dst_file->write_raw(kout3d_data.data(), kout3d_data.size());
-            dst_file->write_raw(vout3d_data.data(), vout3d_data.size());
+            data_ctx->write(kout3d_data.data(), kout3d_data.size());
+            data_ctx->write(vout3d_data.data(), vout3d_data.size());
        }
    }
}
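
For reference, a sketch of the byte layout these writes produce, read directly off the hunks above and identical for the buffer and file destinations:

// serialized state layout produced by llama_copy_state_data (sketch)
//   rng_size                      (size_t)
//   rng_buf                       (LLAMA_MAX_RNG_STATE bytes, zero-padded)
//   logits_cap, logits_size       (size_t each)
//   logits data + zero padding    (logits_cap * sizeof(float) bytes total)
//   embedding_size                (size_t)
//   embedding data                (embedding_size * sizeof(float) bytes)
//   kv_size                       (size_t)
//   kv_ntok                       (int)
//   k tensor, v tensor            (written only when kv_size != 0)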
@@ -4122,7 +4024,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
    // save the context state using stream saving
-    llama_write_state_data_to_file(ctx, &file);
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data(ctx, &data_ctx);
    return true;
}
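
With llama_copy_state_data now taking a llama_data_context, the in-memory path mirrors the file path shown above. A minimal sketch of a caller copying the state into a caller-owned buffer, assuming ctx is an existing llama_context * and the buffer is sized with llama_get_state_size:

std::vector<uint8_t> state_buf(llama_get_state_size(ctx));
llama_data_buffer_context data_ctx(state_buf.data());
llama_copy_state_data(ctx, &data_ctx); // writes the full state into state_buf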