rename sequence state functions
This commit is contained in:
parent
a2b48b95f5
commit
c4443d7ad4
4 changed files with 13 additions and 13 deletions
|
@ -180,8 +180,8 @@ int main(int argc, char ** argv) {
|
||||||
// save seq 0 and load into seq 1
|
// save seq 0 and load into seq 1
|
||||||
{
|
{
|
||||||
// save kv of seq 0
|
// save kv of seq 0
|
||||||
std::vector<uint8_t> seq_store(llama_get_seq_size(ctx3, 0));
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||||
const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0);
|
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
|
||||||
if (ncopy != seq_store.size()) {
|
if (ncopy != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
|
@ -195,7 +195,7 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||||
|
|
||||||
// restore kv into seq 1
|
// restore kv into seq 1
|
||||||
const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1);
|
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
|
||||||
if (nset != seq_store.size()) {
|
if (nset != seq_store.size()) {
|
||||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
|
|
|
@ -1630,9 +1630,9 @@ struct server_context {
|
||||||
|
|
||||||
std::string filename = task.data["filename"];
|
std::string filename = task.data["filename"];
|
||||||
std::string filepath = task.data["filepath"];
|
std::string filepath = task.data["filepath"];
|
||||||
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
|
size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
|
||||||
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
||||||
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
|
size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
|
||||||
GGML_ASSERT(nwrite <= state_size);
|
GGML_ASSERT(nwrite <= state_size);
|
||||||
|
|
||||||
// write the cached token count of the slot->cache_tokens.size()
|
// write the cached token count of the slot->cache_tokens.size()
|
||||||
|
@ -1691,7 +1691,7 @@ struct server_context {
|
||||||
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||||
infile.close();
|
infile.close();
|
||||||
|
|
||||||
size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1);
|
size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
|
||||||
if (nread == 0) {
|
if (nread == 0) {
|
||||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -15059,7 +15059,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
||||||
// save the size of size_t as a uint32_t for safety check
|
// save the size of size_t as a uint32_t for safety check
|
||||||
const size_t size_t_size_size = sizeof(uint32_t);
|
const size_t size_t_size_size = sizeof(uint32_t);
|
||||||
|
|
||||||
|
@ -15109,7 +15109,7 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
||||||
return s_total;
|
return s_total;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
|
||||||
llama_data_buffer_context data_ctx(dst);
|
llama_data_buffer_context data_ctx(dst);
|
||||||
const auto& kv_self = ctx->kv_self;
|
const auto& kv_self = ctx->kv_self;
|
||||||
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||||
|
@ -15214,7 +15214,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
|
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
|
||||||
auto & kv_self = ctx->kv_self;
|
auto & kv_self = ctx->kv_self;
|
||||||
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||||
|
|
||||||
|
|
8
llama.h
8
llama.h
|
@ -623,20 +623,20 @@ extern "C" {
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
size_t n_token_count);
|
size_t n_token_count);
|
||||||
|
|
||||||
LLAMA_API size_t llama_get_seq_size(
|
LLAMA_API size_t llama_state_seq_get_size(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
LLAMA_API size_t llama_copy_seq_data(
|
LLAMA_API size_t llama_state_seq_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst,
|
uint8_t * dst,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence.
|
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence.
|
||||||
// Returns:
|
// Returns:
|
||||||
// - Positive: Ok
|
// - Positive: Ok
|
||||||
// - Zero: Failed to load
|
// - Zero: Failed to load
|
||||||
LLAMA_API size_t llama_set_seq_data(
|
LLAMA_API size_t llama_state_seq_set_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src,
|
const uint8_t * src,
|
||||||
llama_seq_id dest_seq_id);
|
llama_seq_id dest_seq_id);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue