llama : add api for getting/setting the complete state: rng, logits, embedding and kv_cache (#1105)

* reserve correct size for logits

* add functions to get and set the whole llama state:

including rng, logits, embedding and kv_cache

* remove unused variables

* remove trailing whitespace

* fix comment
This commit is contained in:
xaedes 2023-04-22 08:21:32 +02:00 committed by GitHub
parent 50cb666b8a
commit b6e7f9b09e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 1 deletions

12
llama.h
View file

@ -129,6 +129,18 @@ extern "C" {
size_t n_size,
int n_token_count);
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
// Set the state reading from the specified address
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls