rename sequence state functions
This commit is contained in:
parent
a2b48b95f5
commit
c4443d7ad4
4 changed files with 13 additions and 13 deletions
|
@ -180,8 +180,8 @@ int main(int argc, char ** argv) {
|
|||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_get_seq_size(ctx3, 0));
|
||||
const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0);
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
llama_free(ctx3);
|
||||
|
@ -195,7 +195,7 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 1
|
||||
const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1);
|
||||
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
llama_free(ctx3);
|
||||
|
|
|
@ -1630,9 +1630,9 @@ struct server_context {
|
|||
|
||||
std::string filename = task.data["filename"];
|
||||
std::string filepath = task.data["filepath"];
|
||||
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
|
||||
size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
|
||||
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
|
||||
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
|
||||
size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
|
||||
GGML_ASSERT(nwrite <= state_size);
|
||||
|
||||
// write the cached token count of the slot->cache_tokens.size()
|
||||
|
@ -1691,7 +1691,7 @@ struct server_context {
|
|||
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
infile.close();
|
||||
|
||||
size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1);
|
||||
size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
|
||||
if (nread == 0) {
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
|
|
|
@ -15059,7 +15059,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
return true;
|
||||
}
|
||||
|
||||
size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
||||
// save the size of size_t as a uint32_t for safety check
|
||||
const size_t size_t_size_size = sizeof(uint32_t);
|
||||
|
||||
|
@ -15109,7 +15109,7 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
|||
return s_total;
|
||||
}
|
||||
|
||||
size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
|
||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
|
||||
llama_data_buffer_context data_ctx(dst);
|
||||
const auto& kv_self = ctx->kv_self;
|
||||
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||
|
@ -15214,7 +15214,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_
|
|||
return data_ctx.get_size_written();
|
||||
}
|
||||
|
||||
size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
|
||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
|
||||
auto & kv_self = ctx->kv_self;
|
||||
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||
|
||||
|
|
8
llama.h
8
llama.h
|
@ -623,20 +623,20 @@ extern "C" {
|
|||
const llama_token * tokens,
|
||||
size_t n_token_count);
|
||||
|
||||
LLAMA_API size_t llama_get_seq_size(
|
||||
LLAMA_API size_t llama_state_seq_get_size(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
LLAMA_API size_t llama_copy_seq_data(
|
||||
LLAMA_API size_t llama_state_seq_get_data(
|
||||
struct llama_context * ctx,
|
||||
uint8_t * dst,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence.
|
||||
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence.
|
||||
// Returns:
|
||||
// - Positive: Ok
|
||||
// - Zero: Failed to load
|
||||
LLAMA_API size_t llama_set_seq_data(
|
||||
LLAMA_API size_t llama_state_seq_set_data(
|
||||
struct llama_context * ctx,
|
||||
const uint8_t * src,
|
||||
llama_seq_id dest_seq_id);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue