llama : unified KV cache + batch inference API

Georgi Gerganov 2023-09-18 10:08:22 +03:00
parent fad56936d4
commit d29e76937c
10 changed files with 315 additions and 236 deletions

llama.h (34 changed lines)

@@ -60,7 +60,20 @@ extern "C" {
     struct llama_model;
     struct llama_context;
 
-    typedef int llama_token;
+    typedef int32_t llama_pos;
+    typedef int32_t llama_token;
+    typedef int32_t llama_seq_id;
+
+    // data used for batch inference
+    typedef struct llama_batch {
+        uint32_t n_tokens;
+
+        // TODO: not sure about these consts - might just get in the way all the time with no benefit
+        const llama_token  * token;
+        const float        * embd;
+        const llama_pos    * pos;
+        const llama_seq_id * seq_id;
+    } llama_batch;
 
     enum llama_log_level {
         LLAMA_LOG_LEVEL_ERROR = 2,
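For orientation, a minimal usage sketch (not part of the diff) of filling the new batch struct for plain single-sequence decoding; the helper name and the choice of sequence id 0 are illustrative assumptions:

    #include <stddef.h>
    #include <stdint.h>
    // assumes llama.h with the typedefs and struct introduced above

    // Illustrative only: build a batch of n_tokens token ids placed at
    // consecutive positions in sequence 0. The caller keeps the arrays
    // alive for as long as the batch is used.
    static struct llama_batch make_single_seq_batch(const llama_token * tokens,
                                                    llama_pos         * pos,
                                                    llama_seq_id      * seq_id,
                                                    uint32_t            n_tokens) {
        for (uint32_t i = 0; i < n_tokens; ++i) {
            pos[i]    = (llama_pos) i; // position of token i within its sequence
            seq_id[i] = 0;             // every token belongs to sequence 0
        }

        struct llama_batch batch = {
            /*.n_tokens =*/ n_tokens,
            /*.token    =*/ tokens, // token ids are used, so embd stays NULL
            /*.embd     =*/ NULL,
            /*.pos      =*/ pos,
            /*.seq_id   =*/ seq_id,
        };

        return batch;
    }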
@@ -289,8 +302,15 @@ extern "C" {
                       const char * path_base_model,
                              int   n_threads);
 
+    //
+    // KV cache API
+    //
+
     // Returns the number of tokens in the KV cache
-    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
+    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
+            "avoid using this, it will be removed in the future");
+
+    LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
 
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
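A sketch of what the new call enables, assuming (the hunk does not state it) that p0/p1 delimit a half-open range of token positions; the wrapper name and n_keep/n_past are illustrative:

    #include <stdint.h>
    // assumes llama.h

    // Illustrative only: instead of a full context swap, drop the cached
    // KV entries for positions [n_keep, n_past) and keep the prompt prefix.
    static void prune_kv(struct llama_context * ctx, int32_t n_keep, int32_t n_past) {
        llama_kv_clear(ctx, n_keep, n_past);
    }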
@@ -319,7 +339,7 @@ extern "C" {
     LLAMA_API int llama_eval(
             struct llama_context * ctx,
                const llama_token * tokens,
-                             int   n_tokens,
+                        uint32_t   n_tokens,
                              int   n_past,
                              int   n_threads);
 
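Only the width of n_tokens changes here; an illustrative call site (the wrapper name and thread count are assumptions):

    #include <stdint.h>
    // assumes llama.h

    // Illustrative only: evaluate n_tokens prompt tokens starting at past
    // position 0; llama_eval returns 0 on success.
    static int eval_prompt(struct llama_context * ctx, const llama_token * tokens, uint32_t n_tokens) {
        return llama_eval(ctx, tokens, n_tokens, /*n_past =*/ 0, /*n_threads =*/ 4);
    }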
@@ -327,16 +347,10 @@ extern "C" {
     LLAMA_API int llama_eval_embd(
             struct llama_context * ctx,
                      const float * embd,
-                             int   n_tokens,
+                        uint32_t   n_tokens,
                              int   n_past,
                              int   n_threads);
 
-    // Export a static computation graph for context of 511 and batch size of 1
-    // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
-    // parameters here to keep things simple
-    // IMPORTANT: do not use for anything else other than debugging and testing!
-    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
-
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
     // Can be mutated in order to change the probabilities of the next token
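A sketch of what those comments permit, assuming llama_get_logits() and llama_n_vocab() from elsewhere in this header, and assuming the context was set up to store logits for all tokens (otherwise the buffer holds a single row); the function name and the +2.0f bias are illustrative:

    #include <stddef.h>
    #include <stdint.h>
    // assumes llama.h

    // Illustrative only: after llama_eval() on n_tokens tokens, mutate the
    // last row of logits to make token_id more likely in the next sample.
    static void bias_last_token(struct llama_context * ctx, uint32_t n_tokens, llama_token token_id) {
        float   * logits  = llama_get_logits(ctx); // n_tokens x n_vocab, row-major
        const int n_vocab = llama_n_vocab(ctx);

        float * last_row = logits + (size_t) (n_tokens - 1) * (size_t) n_vocab;

        last_row[token_id] += 2.0f; // raising a logit raises that token's probability
    }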