rename sequence state functions

This commit is contained in:
Jan Boon 2024-03-28 22:10:04 +08:00
parent a2b48b95f5
commit c4443d7ad4
4 changed files with 13 additions and 13 deletions

View file

@@ -180,8 +180,8 @@ int main(int argc, char ** argv) {
// save seq 0 and load into seq 1 // save seq 0 and load into seq 1
{ {
// save kv of seq 0 // save kv of seq 0
std::vector<uint8_t> seq_store(llama_get_seq_size(ctx3, 0)); std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0); const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
if (ncopy != seq_store.size()) { if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3); llama_free(ctx3);
@@ -195,7 +195,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : kv cache cleared\n", __func__); fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1 // restore kv into seq 1
const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1); const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
if (nset != seq_store.size()) { if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3); llama_free(ctx3);

View file

@@ -1630,9 +1630,9 @@ struct server_context {
std::string filename = task.data["filename"]; std::string filename = task.data["filename"];
std::string filepath = task.data["filepath"]; std::string filepath = task.data["filepath"];
size_t state_size = llama_get_seq_size(ctx, slot->id + 1); size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
GGML_ASSERT(nwrite <= state_size); GGML_ASSERT(nwrite <= state_size);
// write the cached token count of the slot->cache_tokens.size() // write the cached token count of the slot->cache_tokens.size()
@@ -1691,7 +1691,7 @@ struct server_context {
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>()); std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
infile.close(); infile.close();
size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1); size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
if (nread == 0) { if (nread == 0) {
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break; break;

View file

@@ -15059,7 +15059,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
return true; return true;
} }
size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) { size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
// save the size of size_t as a uint32_t for safety check // save the size of size_t as a uint32_t for safety check
const size_t size_t_size_size = sizeof(uint32_t); const size_t size_t_size_size = sizeof(uint32_t);
@@ -15109,7 +15109,7 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
return s_total; return s_total;
} }
size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
llama_data_buffer_context data_ctx(dst); llama_data_buffer_context data_ctx(dst);
const auto& kv_self = ctx->kv_self; const auto& kv_self = ctx->kv_self;
GGML_ASSERT(!kv_self.recurrent); // not implemented GGML_ASSERT(!kv_self.recurrent); // not implemented
@@ -15214,7 +15214,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_
return data_ctx.get_size_written(); return data_ctx.get_size_written();
} }
size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
auto & kv_self = ctx->kv_self; auto & kv_self = ctx->kv_self;
GGML_ASSERT(!kv_self.recurrent); // not implemented GGML_ASSERT(!kv_self.recurrent); // not implemented

View file

@@ -623,20 +623,20 @@ extern "C" {
const llama_token * tokens, const llama_token * tokens,
size_t n_token_count); size_t n_token_count);
LLAMA_API size_t llama_get_seq_size( LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API size_t llama_copy_seq_data( LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx, struct llama_context * ctx,
uint8_t * dst, uint8_t * dst,
llama_seq_id seq_id); llama_seq_id seq_id);
// Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence. // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence.
// Returns: // Returns:
// - Positive: Ok // - Positive: Ok
// - Zero: Failed to load // - Zero: Failed to load
LLAMA_API size_t llama_set_seq_data( LLAMA_API size_t llama_state_seq_set_data(
struct llama_context * ctx, struct llama_context * ctx,
const uint8_t * src, const uint8_t * src,
llama_seq_id dest_seq_id); llama_seq_id dest_seq_id);