rename sequence state functions

This commit is contained in:
Jan Boon 2024-03-28 22:10:04 +08:00
parent a2b48b95f5
commit c4443d7ad4
4 changed files with 13 additions and 13 deletions

View file

@ -180,8 +180,8 @@ int main(int argc, char ** argv) {
// save seq 0 and load into seq 1
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_get_seq_size(ctx3, 0));
const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0);
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
@ -195,7 +195,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1
const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1);
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);

View file

@ -1630,9 +1630,9 @@ struct server_context {
std::string filename = task.data["filename"];
std::string filepath = task.data["filepath"];
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
GGML_ASSERT(nwrite <= state_size);
// write the cached token count of the slot->cache_tokens.size()
@ -1691,7 +1691,7 @@ struct server_context {
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
infile.close();
size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1);
size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
if (nread == 0) {
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;

View file

@ -15059,7 +15059,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
return true;
}
size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
// save the size of size_t as a uint32_t for safety check
const size_t size_t_size_size = sizeof(uint32_t);
@ -15109,7 +15109,7 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
return s_total;
}
size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
llama_data_buffer_context data_ctx(dst);
const auto& kv_self = ctx->kv_self;
GGML_ASSERT(!kv_self.recurrent); // not implemented
@ -15214,7 +15214,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_
return data_ctx.get_size_written();
}
size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
auto & kv_self = ctx->kv_self;
GGML_ASSERT(!kv_self.recurrent); // not implemented

View file

@ -623,20 +623,20 @@ extern "C" {
const llama_token * tokens,
size_t n_token_count);
LLAMA_API size_t llama_get_seq_size(
LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx,
llama_seq_id seq_id);
LLAMA_API size_t llama_copy_seq_data(
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
llama_seq_id seq_id);
// Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence.
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence.
// Returns:
// - Positive: Ok
// - Zero: Failed to load
LLAMA_API size_t llama_set_seq_data(
LLAMA_API size_t llama_state_seq_set_data(
struct llama_context * ctx,
const uint8_t * src,
llama_seq_id dest_seq_id);